[7.x] Allow integer types for classification's dependent variable (#47902) (#48080)

This commit is contained in:
Przemysław Witek 2019-10-16 11:09:56 +02:00 committed by GitHub
parent 8243e99134
commit 8f815240b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 262 additions and 66 deletions

View File

@ -20,6 +20,9 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -58,6 +61,12 @@ public class Classification implements DataFrameAnalysis {
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
} }
private static final Set<String> ALLOWED_DEPENDENT_VARIABLE_TYPES =
Collections.unmodifiableSet(
Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool())
.flatMap(Set::stream)
.collect(Collectors.toSet()));
private final String dependentVariable; private final String dependentVariable;
private final BoostedTreeParams boostedTreeParams; private final BoostedTreeParams boostedTreeParams;
private final String predictionFieldName; private final String predictionFieldName;
@ -147,9 +156,17 @@ public class Classification implements DataFrameAnalysis {
return true; return true;
} }
@Override
public Set<String> getAllowedCategoricalTypes(String fieldName) {
if (dependentVariable.equals(fieldName)) {
return ALLOWED_DEPENDENT_VARIABLE_TYPES;
}
return Types.categorical();
}
@Override @Override
public List<RequiredField> getRequiredFields() { public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical())); return Collections.singletonList(new RequiredField(dependentVariable, ALLOWED_DEPENDENT_VARIABLE_TYPES));
} }
@Override @Override

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
@ -23,6 +24,12 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
*/ */
boolean supportsCategoricalFields(); boolean supportsCategoricalFields();
/**
* @param fieldName field for which the allowed categorical types should be returned
* @return The types treated as categorical for the given field
*/
Set<String> getAllowedCategoricalTypes(String fieldName);
/** /**
* @return The names and types of the fields that analyzed documents must have for the analysis to operate * @return The names and types of the fields that analyzed documents must have for the analysis to operate
*/ */

View File

@ -22,6 +22,7 @@ import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
public class OutlierDetection implements DataFrameAnalysis { public class OutlierDetection implements DataFrameAnalysis {
@ -213,6 +214,11 @@ public class OutlierDetection implements DataFrameAnalysis {
return false; return false;
} }
@Override
public Set<String> getAllowedCategoricalTypes(String fieldName) {
return Collections.emptySet();
}
@Override @Override
public List<RequiredField> getRequiredFields() { public List<RequiredField> getRequiredFields() {
return Collections.emptyList(); return Collections.emptyList();

View File

@ -20,6 +20,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -134,6 +135,11 @@ public class Regression implements DataFrameAnalysis {
return true; return true;
} }
@Override
public Set<String> getAllowedCategoricalTypes(String fieldName) {
return Types.categorical();
}
@Override @Override
public List<RequiredField> getRequiredFields() { public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical())); return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));

View File

@ -5,7 +5,11 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.analyses; package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.IpFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper.NumberType;
import org.elasticsearch.index.mapper.TextFieldMapper;
import java.util.Collections; import java.util.Collections;
import java.util.Set; import java.util.Set;
@ -21,16 +25,22 @@ public final class Types {
private static final Set<String> CATEGORICAL_TYPES = private static final Set<String> CATEGORICAL_TYPES =
Collections.unmodifiableSet( Collections.unmodifiableSet(
Stream.of("text", "keyword", "ip") Stream.of(TextFieldMapper.CONTENT_TYPE, KeywordFieldMapper.CONTENT_TYPE, IpFieldMapper.CONTENT_TYPE)
.collect(Collectors.toSet())); .collect(Collectors.toSet()));
private static final Set<String> NUMERICAL_TYPES = private static final Set<String> NUMERICAL_TYPES =
Collections.unmodifiableSet( Collections.unmodifiableSet(
Stream.concat( Stream.concat(Stream.of(NumberType.values()).map(NumberType::typeName), Stream.of("scaled_float"))
Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
Stream.of("scaled_float"))
.collect(Collectors.toSet())); .collect(Collectors.toSet()));
private static final Set<String> DISCRETE_NUMERICAL_TYPES =
Collections.unmodifiableSet(
Stream.of(NumberType.BYTE, NumberType.SHORT, NumberType.INTEGER, NumberType.LONG)
.map(NumberType::typeName)
.collect(Collectors.toSet()));
private static final Set<String> BOOL_TYPES = Collections.singleton(BooleanFieldMapper.CONTENT_TYPE);
public static Set<String> categorical() { public static Set<String> categorical() {
return CATEGORICAL_TYPES; return CATEGORICAL_TYPES;
} }
@ -38,4 +48,12 @@ public final class Types {
public static Set<String> numerical() { public static Set<String> numerical() {
return NUMERICAL_TYPES; return NUMERICAL_TYPES;
} }
public static Set<String> discreteNumerical() {
return DISCRETE_NUMERICAL_TYPES;
}
public static Set<String> bool() {
return BOOL_TYPES;
}
} }

View File

@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
@ -24,6 +25,7 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Function;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -34,12 +36,17 @@ import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.in;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.startsWith;
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String BOOLEAN_FIELD = "boolean-field";
private static final String NUMERICAL_FIELD = "numerical-field"; private static final String NUMERICAL_FIELD = "numerical-field";
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
private static final String KEYWORD_FIELD = "keyword-field"; private static final String KEYWORD_FIELD = "keyword-field";
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0)); private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat")); private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
private String jobId; private String jobId;
@ -53,7 +60,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("classification_single_numeric_feature_and_mixed_data_set"); initialize("classification_single_numeric_feature_and_mixed_data_set");
indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config); registerAnalytics(config);
@ -91,7 +98,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_100"); initialize("classification_only_training_data_and_training_percent_is_100");
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); indexData(sourceIndex, 300, 0, KEYWORD_FIELD);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config); registerAnalytics(config);
@ -126,17 +133,19 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Finished analysis"); "Finished analysis");
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
initialize("classification_only_training_data_and_training_percent_is_50"); String jobId, String dependentVariable, List<T> dependentVariableValues, Function<String, T> parser) throws Exception {
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); initialize(jobId);
indexData(sourceIndex, 300, 0, dependentVariable);
int numTopClasses = 2;
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig config =
buildAnalytics( buildAnalytics(
jobId, jobId,
sourceIndex, sourceIndex,
destIndex, destIndex,
null, null,
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0)); new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0));
registerAnalytics(config); registerAnalytics(config);
putAnalytics(config); putAnalytics(config);
@ -151,8 +160,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) { for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); String predictedClassField = dependentVariable + "_prediction";
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey(predictedClassField), is(true));
T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField));
assertThat(predictedClassValue, is(in(dependentVariableValues)));
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser);
assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results // Let's just assert there's both training and non-training results
@ -161,11 +173,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
} else { } else {
nonTrainingRowsCount++; nonTrainingRowsCount++;
} }
assertThat(resultsObject.containsKey("top_classes"), is(false));
} }
assertThat(trainingRowsCount, greaterThan(0)); assertThat(trainingRowsCount, greaterThan(0));
assertThat(nonTrainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0));
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
@ -178,9 +188,32 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Finished analysis"); "Finished analysis");
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf);
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
ElasticsearchStatusException e = expectThrows(
ElasticsearchStatusException.class,
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf));
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf);
}
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
initialize("classification_top_classes_requested"); initialize("classification_top_classes_requested");
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
int numTopClasses = 2; int numTopClasses = 2;
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig config =
@ -206,7 +239,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertTopClasses(resultsObject, numTopClasses); assertTopClasses(resultsObject, numTopClasses, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
} }
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
@ -223,7 +256,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testDependentVariableCardinalityTooHighError() { public void testDependentVariableCardinalityTooHighError() {
initialize("cardinality_too_high"); initialize("cardinality_too_high");
indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox")); indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
// Index one more document with a class different than the two already used.
client().execute(
IndexAction.INSTANCE,
new IndexRequest(sourceIndex).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).source(KEYWORD_FIELD, "fox")).actionGet();
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config); registerAnalytics(config);
@ -240,28 +277,43 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
this.destIndex = sourceIndex + "_results"; this.destIndex = sourceIndex + "_results";
} }
private static void indexData(String sourceIndex, private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
int numTrainingRows, int numNonTrainingRows,
List<Double> numericalFieldValues, List<String> keywordFieldValues) {
client().admin().indices().prepareCreate(sourceIndex) client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword") .addMapping("_doc",
BOOLEAN_FIELD, "type=boolean",
NUMERICAL_FIELD, "type=double",
DISCRETE_NUMERICAL_FIELD, "type=integer",
KEYWORD_FIELD, "type=keyword")
.get(); .get();
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < numTrainingRows; i++) { for (int i = 0; i < numTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size()); List<Object> source = Arrays.asList(
String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size()); BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
IndexRequest indexRequest = new IndexRequest(sourceIndex) DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
.source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue); KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest); bulkRequestBuilder.add(indexRequest);
} }
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) { for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size()); List<Object> source = new ArrayList<>();
if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
IndexRequest indexRequest = new IndexRequest(sourceIndex) source.addAll(Arrays.asList(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
.source(NUMERICAL_FIELD, numericalValue); }
if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(Arrays.asList(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
}
if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(
Arrays.asList(
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size())));
}
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
source.addAll(Arrays.asList(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest); bulkRequestBuilder.add(indexRequest);
} }
BulkResponse bulkResponse = bulkRequestBuilder.get(); BulkResponse bulkResponse = bulkRequestBuilder.get();
@ -289,7 +341,12 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
return resultsObject; return resultsObject;
} }
private static void assertTopClasses(Map<String, Object> resultsObject, int numTopClasses) { private static <T> void assertTopClasses(
Map<String, Object> resultsObject,
int numTopClasses,
String dependentVariable,
List<T> dependentVariableValues,
Function<String, T> parser) {
assertThat(resultsObject.containsKey("top_classes"), is(true)); assertThat(resultsObject.containsKey("top_classes"), is(true));
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes"); List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
@ -302,9 +359,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
classProbabilities.add((Double) topClass.get("class_probability")); classProbabilities.add((Double) topClass.get("class_probability"));
} }
// Assert that all the predicted class names come from the set of keyword field values. // Assert that all the predicted class names come from the set of keyword field values.
classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES)))); classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
// Assert that the first class listed in top classes is the same as the predicted class. // Assert that the first class listed in top classes is the same as the predicted class.
assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction"))); assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
// Assert that all the class probabilities lie within [0, 1] interval. // Assert that all the class probabilities lie within [0, 1] interval.
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)))); classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
// Assert that the top classes are listed in the order of decreasing probabilities. // Assert that the top classes are listed in the order of decreasing probabilities.

View File

@ -23,7 +23,7 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex;
@ -31,7 +31,6 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
@ -265,15 +264,11 @@ public class DataFrameDataExtractor {
.setTrackTotalHits(true); .setTrackTotalHits(true);
} }
public Set<String> getCategoricalFields() { public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
Set<String> categoricalFields = new HashSet<>(); return context.extractedFields.getAllFields().stream()
for (ExtractedField extractedField : context.extractedFields.getAllFields()) { .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()).containsAll(extractedField.getTypes()))
String fieldName = extractedField.getName(); .map(ExtractedField::getName)
if (Types.categorical().containsAll(extractedField.getTypes())) { .collect(Collectors.toSet());
categoricalFields.add(fieldName);
}
}
return categoricalFields;
} }
public static class DataSummary { public static class DataSummary {

View File

@ -264,7 +264,14 @@ public class ExtractedFieldsDetector {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size()); List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getAllFields()) { for (ExtractedField field : extractedFields.getAllFields()) {
if (isBoolean(field.getTypes())) { if (isBoolean(field.getTypes())) {
adjusted.add(new BooleanAsInteger(field)); if (config.getAnalysis().getAllowedCategoricalTypes(field.getAlias()).contains(BooleanFieldMapper.CONTENT_TYPE)) {
// We convert boolean field to string if it is a categorical dependent variable
adjusted.add(new BooleanMapper<>(field, Boolean.TRUE.toString(), Boolean.FALSE.toString()));
} else {
// We convert boolean fields to integers with values 0, 1 as this is the preferred
// way to consume such features in the analytics process.
adjusted.add(new BooleanMapper<>(field, 1, 0));
}
} else { } else {
adjusted.add(field); adjusted.add(field);
} }
@ -277,21 +284,24 @@ public class ExtractedFieldsDetector {
} }
/** /**
* We convert boolean fields to integers with values 0, 1 as this is the preferred * {@link BooleanMapper} makes boolean field behave as a field of different type.
* way to consume such features in the analytics process.
*/ */
private static class BooleanAsInteger extends ExtractedField { private static final class BooleanMapper<T> extends ExtractedField {
protected BooleanAsInteger(ExtractedField field) { private final T trueValue;
private final T falseValue;
BooleanMapper(ExtractedField field, T trueValue, T falseValue) {
super(field.getAlias(), field.getName(), Collections.singleton(BooleanFieldMapper.CONTENT_TYPE), ExtractionMethod.DOC_VALUE); super(field.getAlias(), field.getName(), Collections.singleton(BooleanFieldMapper.CONTENT_TYPE), ExtractionMethod.DOC_VALUE);
this.trueValue = trueValue;
this.falseValue = falseValue;
} }
@Override @Override
public Object[] value(SearchHit hit) { public Object[] value(SearchHit hit) {
DocumentField keyValue = hit.field(name); DocumentField keyValue = hit.field(name);
if (keyValue != null) { if (keyValue != null) {
List<Object> values = keyValue.getValues().stream().map(v -> Boolean.TRUE.equals(v) ? 1 : 0).collect(Collectors.toList()); return keyValue.getValues().stream().map(v -> Boolean.TRUE.equals(v) ? trueValue : falseValue).toArray();
return values.toArray(new Object[0]);
} }
return new Object[0]; return new Object[0];
} }

View File

@ -362,7 +362,7 @@ public class AnalyticsProcessManager {
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) {
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
Set<String> categoricalFields = dataExtractor.getCategoricalFields(); Set<String> categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis());
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(config.getId(), dataSummary.rows, dataSummary.cols, AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(config.getId(), dataSummary.rows, dataSummary.cols,
config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis()); config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis());
return processConfig; return processConfig;

View File

@ -57,7 +57,7 @@ public class MemoryUsageEstimationProcessManager {
DataFrameDataExtractorFactory dataExtractorFactory) { DataFrameDataExtractorFactory dataExtractorFactory) {
DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false);
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
Set<String> categoricalFields = dataExtractor.getCategoricalFields(); Set<String> categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis());
if (dataSummary.rows == 0) { if (dataSummary.rows == 0) {
return new MemoryUsageEstimationResult(ByteSizeValue.ZERO, ByteSizeValue.ZERO); return new MemoryUsageEstimationResult(ByteSizeValue.ZERO, ByteSizeValue.ZERO);
} }

View File

@ -24,6 +24,9 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
@ -41,7 +44,9 @@ import java.util.Optional;
import java.util.Queue; import java.util.Queue;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
@ -384,6 +389,36 @@ public class DataFrameDataExtractorTests extends ESTestCase {
assertThat(dataExtractor.hasNext(), is(false)); assertThat(dataExtractor.hasNext(), is(false));
} }
public void testGetCategoricalFields() {
extractedFields = new ExtractedFields(Arrays.asList(
ExtractedField.newField("field_boolean", Collections.singleton("boolean"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_float", Collections.singleton("float"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_double", Collections.singleton("double"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_byte", Collections.singleton("byte"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_short", Collections.singleton("short"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_integer", Collections.singleton("integer"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_long", Collections.singleton("long"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_keyword", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
ExtractedField.newField("field_text", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE)));
TestExtractor dataExtractor = createExtractor(true, true);
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")), containsInAnyOrder("field_keyword", "field_text"));
assertThat(dataExtractor.getCategoricalFields(new Regression("field_long")), containsInAnyOrder("field_keyword", "field_text"));
assertThat(dataExtractor.getCategoricalFields(new Regression("field_boolean")), containsInAnyOrder("field_keyword", "field_text"));
assertThat(
dataExtractor.getCategoricalFields(new Classification("field_keyword")),
containsInAnyOrder("field_keyword", "field_text"));
assertThat(
dataExtractor.getCategoricalFields(new Classification("field_long")),
containsInAnyOrder("field_keyword", "field_text", "field_long"));
assertThat(
dataExtractor.getCategoricalFields(new Classification("field_boolean")),
containsInAnyOrder("field_keyword", "field_text", "field_boolean"));
}
private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) { private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) {
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(
JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues); JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues);

View File

@ -14,6 +14,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
@ -213,6 +214,22 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
"expected types are [byte, double, float, half_float, integer, long, scaled_float, short]")); "expected types are [byte, double, float, half_float, integer, long, scaled_float, short]"));
} }
public void testDetect_GivenClassificationAndRequiredFieldHasInvalidType() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_float", "float")
.addAggregatableField("some_long", "long")
.addAggregatableField("some_keyword", "keyword")
.addAggregatableField("foo", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildClassificationConfig("some_float"), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("invalid types [float] for required field [some_float]; " +
"expected types are [boolean, byte, integer, ip, keyword, long, short, text]"));
}
public void testDetect_GivenIgnoredField() { public void testDetect_GivenIgnoredField() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("_id", "float").build(); .addAggregatableField("_id", "float").build();
@ -467,7 +484,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.ExtractionMethod.SOURCE))); contains(equalTo(ExtractedField.ExtractionMethod.SOURCE)));
} }
public void testDetect_GivenBooleanField() { public void testDetect_GivenBooleanField_BooleanMappedAsInteger() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_boolean", "boolean") .addAggregatableField("some_boolean", "boolean")
.build(); .build();
@ -483,19 +500,38 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)); assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE));
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build(); SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
Object[] values = booleanField.value(hit); assertThat(booleanField.value(hit), arrayContaining(1));
assertThat(values.length, equalTo(1));
assertThat(values[0], equalTo(1));
hit = new SearchHitBuilder(42).addField("some_boolean", false).build(); hit = new SearchHitBuilder(42).addField("some_boolean", false).build();
values = booleanField.value(hit); assertThat(booleanField.value(hit), arrayContaining(0));
assertThat(values.length, equalTo(1));
assertThat(values[0], equalTo(0));
hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build(); hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build();
values = booleanField.value(hit); assertThat(booleanField.value(hit), arrayContaining(0, 1, 0));
assertThat(values.length, equalTo(3)); }
assertThat(values, arrayContaining(0, 1, 0));
public void testDetect_GivenBooleanField_BooleanMappedAsString() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_boolean", "boolean")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildClassificationConfig("some_boolean"), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
assertThat(allFields.size(), equalTo(1));
ExtractedField booleanField = allFields.get(0);
assertThat(booleanField.getTypes(), contains("boolean"));
assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE));
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
assertThat(booleanField.value(hit), arrayContaining("true"));
hit = new SearchHitBuilder(42).addField("some_boolean", false).build();
assertThat(booleanField.value(hit), arrayContaining("false"));
hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build();
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
} }
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() { private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
@ -526,6 +562,15 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
} }
private static DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) {
return new DataFrameAnalyticsConfig.Builder()
.setId("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
.setAnalysis(new Classification(dependentVariable))
.build();
}
private static class MockFieldCapsResponseBuilder { private static class MockFieldCapsResponseBuilder {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>(); private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();