This commit is contained in:
parent
8243e99134
commit
8f815240b3
|
@ -20,6 +20,9 @@ import java.util.HashMap;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
@ -58,6 +61,12 @@ public class Classification implements DataFrameAnalysis {
|
|||
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final Set<String> ALLOWED_DEPENDENT_VARIABLE_TYPES =
|
||||
Collections.unmodifiableSet(
|
||||
Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool())
|
||||
.flatMap(Set::stream)
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
private final String dependentVariable;
|
||||
private final BoostedTreeParams boostedTreeParams;
|
||||
private final String predictionFieldName;
|
||||
|
@ -147,9 +156,17 @@ public class Classification implements DataFrameAnalysis {
|
|||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getAllowedCategoricalTypes(String fieldName) {
|
||||
if (dependentVariable.equals(fieldName)) {
|
||||
return ALLOWED_DEPENDENT_VARIABLE_TYPES;
|
||||
}
|
||||
return Types.categorical();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RequiredField> getRequiredFields() {
|
||||
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
|
||||
return Collections.singletonList(new RequiredField(dependentVariable, ALLOWED_DEPENDENT_VARIABLE_TYPES));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
|
|||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
|
||||
|
||||
|
@ -23,6 +24,12 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
|
|||
*/
|
||||
boolean supportsCategoricalFields();
|
||||
|
||||
/**
|
||||
* @param fieldName field for which the allowed categorical types should be returned
|
||||
* @return The types treated as categorical for the given field
|
||||
*/
|
||||
Set<String> getAllowedCategoricalTypes(String fieldName);
|
||||
|
||||
/**
|
||||
* @return The names and types of the fields that analyzed documents must have for the analysis to operate
|
||||
*/
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.List;
|
|||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
public class OutlierDetection implements DataFrameAnalysis {
|
||||
|
||||
|
@ -213,6 +214,11 @@ public class OutlierDetection implements DataFrameAnalysis {
|
|||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getAllowedCategoricalTypes(String fieldName) {
|
||||
return Collections.emptySet();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RequiredField> getRequiredFields() {
|
||||
return Collections.emptyList();
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.util.HashMap;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
@ -134,6 +135,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getAllowedCategoricalTypes(String fieldName) {
|
||||
return Types.categorical();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RequiredField> getRequiredFields() {
|
||||
return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
|
||||
|
|
|
@ -5,7 +5,11 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
||||
import org.elasticsearch.index.mapper.IpFieldMapper;
|
||||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper.NumberType;
|
||||
import org.elasticsearch.index.mapper.TextFieldMapper;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Set;
|
||||
|
@ -21,16 +25,22 @@ public final class Types {
|
|||
|
||||
private static final Set<String> CATEGORICAL_TYPES =
|
||||
Collections.unmodifiableSet(
|
||||
Stream.of("text", "keyword", "ip")
|
||||
Stream.of(TextFieldMapper.CONTENT_TYPE, KeywordFieldMapper.CONTENT_TYPE, IpFieldMapper.CONTENT_TYPE)
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
private static final Set<String> NUMERICAL_TYPES =
|
||||
Collections.unmodifiableSet(
|
||||
Stream.concat(
|
||||
Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
|
||||
Stream.of("scaled_float"))
|
||||
Stream.concat(Stream.of(NumberType.values()).map(NumberType::typeName), Stream.of("scaled_float"))
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
private static final Set<String> DISCRETE_NUMERICAL_TYPES =
|
||||
Collections.unmodifiableSet(
|
||||
Stream.of(NumberType.BYTE, NumberType.SHORT, NumberType.INTEGER, NumberType.LONG)
|
||||
.map(NumberType::typeName)
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
private static final Set<String> BOOL_TYPES = Collections.singleton(BooleanFieldMapper.CONTENT_TYPE);
|
||||
|
||||
public static Set<String> categorical() {
|
||||
return CATEGORICAL_TYPES;
|
||||
}
|
||||
|
@ -38,4 +48,12 @@ public final class Types {
|
|||
public static Set<String> numerical() {
|
||||
return NUMERICAL_TYPES;
|
||||
}
|
||||
|
||||
public static Set<String> discreteNumerical() {
|
||||
return DISCRETE_NUMERICAL_TYPES;
|
||||
}
|
||||
|
||||
public static Set<String> bool() {
|
||||
return BOOL_TYPES;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException;
|
|||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.get.GetResponse;
|
||||
import org.elasticsearch.action.index.IndexAction;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
|
@ -24,6 +25,7 @@ import java.util.Arrays;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -34,12 +36,17 @@ import static org.hamcrest.Matchers.hasSize;
|
|||
import static org.hamcrest.Matchers.in;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||
import static org.hamcrest.Matchers.startsWith;
|
||||
|
||||
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
private static final String BOOLEAN_FIELD = "boolean-field";
|
||||
private static final String NUMERICAL_FIELD = "numerical-field";
|
||||
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
|
||||
private static final String KEYWORD_FIELD = "keyword-field";
|
||||
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
|
||||
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
|
||||
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
|
||||
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
|
||||
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
|
||||
|
||||
private String jobId;
|
||||
|
@ -53,7 +60,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||
initialize("classification_single_numeric_feature_and_mixed_data_set");
|
||||
indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
||||
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
||||
registerAnalytics(config);
|
||||
|
@ -91,7 +98,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||
initialize("classification_only_training_data_and_training_percent_is_100");
|
||||
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
||||
indexData(sourceIndex, 300, 0, KEYWORD_FIELD);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
||||
registerAnalytics(config);
|
||||
|
@ -126,17 +133,19 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"Finished analysis");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||
initialize("classification_only_training_data_and_training_percent_is_50");
|
||||
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
||||
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
String jobId, String dependentVariable, List<T> dependentVariableValues, Function<String, T> parser) throws Exception {
|
||||
initialize(jobId);
|
||||
indexData(sourceIndex, 300, 0, dependentVariable);
|
||||
|
||||
int numTopClasses = 2;
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(
|
||||
jobId,
|
||||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
|
||||
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -151,8 +160,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
|
||||
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
||||
String predictedClassField = dependentVariable + "_prediction";
|
||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||
T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField));
|
||||
assertThat(predictedClassValue, is(in(dependentVariableValues)));
|
||||
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser);
|
||||
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
// Let's just assert there's both training and non-training results
|
||||
|
@ -161,11 +173,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
} else {
|
||||
nonTrainingRowsCount++;
|
||||
}
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
||||
}
|
||||
assertThat(trainingRowsCount, greaterThan(0));
|
||||
assertThat(nonTrainingRowsCount, greaterThan(0));
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
|
@ -178,9 +188,32 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"Finished analysis");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
|
||||
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception {
|
||||
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf);
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
|
||||
ElasticsearchStatusException e = expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf));
|
||||
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception {
|
||||
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf);
|
||||
}
|
||||
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
|
||||
initialize("classification_top_classes_requested");
|
||||
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
||||
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
||||
|
||||
int numTopClasses = 2;
|
||||
DataFrameAnalyticsConfig config =
|
||||
|
@ -206,7 +239,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
|
||||
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertTopClasses(resultsObject, numTopClasses);
|
||||
assertTopClasses(resultsObject, numTopClasses, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
@ -223,7 +256,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testDependentVariableCardinalityTooHighError() {
|
||||
initialize("cardinality_too_high");
|
||||
indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));
|
||||
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
|
||||
// Index one more document with a class different than the two already used.
|
||||
client().execute(
|
||||
IndexAction.INSTANCE,
|
||||
new IndexRequest(sourceIndex).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).source(KEYWORD_FIELD, "fox")).actionGet();
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
||||
registerAnalytics(config);
|
||||
|
@ -240,28 +277,43 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
this.destIndex = sourceIndex + "_results";
|
||||
}
|
||||
|
||||
private static void indexData(String sourceIndex,
|
||||
int numTrainingRows, int numNonTrainingRows,
|
||||
List<Double> numericalFieldValues, List<String> keywordFieldValues) {
|
||||
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
|
||||
client().admin().indices().prepareCreate(sourceIndex)
|
||||
.addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
|
||||
.addMapping("_doc",
|
||||
BOOLEAN_FIELD, "type=boolean",
|
||||
NUMERICAL_FIELD, "type=double",
|
||||
DISCRETE_NUMERICAL_FIELD, "type=integer",
|
||||
KEYWORD_FIELD, "type=keyword")
|
||||
.get();
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < numTrainingRows; i++) {
|
||||
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
|
||||
String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
|
||||
List<Object> source = Arrays.asList(
|
||||
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
|
||||
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
|
||||
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
|
||||
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
|
||||
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
||||
.source(NUMERICAL_FIELD, numericalValue);
|
||||
List<Object> source = new ArrayList<>();
|
||||
if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
|
||||
source.addAll(Arrays.asList(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
|
||||
}
|
||||
if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
|
||||
source.addAll(Arrays.asList(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
|
||||
}
|
||||
if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
|
||||
source.addAll(
|
||||
Arrays.asList(
|
||||
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size())));
|
||||
}
|
||||
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
|
||||
source.addAll(Arrays.asList(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
|
||||
}
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
|
@ -289,7 +341,12 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
return resultsObject;
|
||||
}
|
||||
|
||||
private static void assertTopClasses(Map<String, Object> resultsObject, int numTopClasses) {
|
||||
private static <T> void assertTopClasses(
|
||||
Map<String, Object> resultsObject,
|
||||
int numTopClasses,
|
||||
String dependentVariable,
|
||||
List<T> dependentVariableValues,
|
||||
Function<String, T> parser) {
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(true));
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
|
||||
|
@ -302,9 +359,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
classProbabilities.add((Double) topClass.get("class_probability"));
|
||||
}
|
||||
// Assert that all the predicted class names come from the set of keyword field values.
|
||||
classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
|
||||
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
|
||||
// Assert that the first class listed in top classes is the same as the predicted class.
|
||||
assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
|
||||
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
|
||||
// Assert that all the class probabilities lie within [0, 1] interval.
|
||||
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
|
||||
// Assert that the top classes are listed in the order of decreasing probabilities.
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.elasticsearch.search.SearchHit;
|
|||
import org.elasticsearch.search.fetch.StoredFieldsContext;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex;
|
||||
|
||||
|
@ -31,7 +31,6 @@ import java.io.IOException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.NoSuchElementException;
|
||||
|
@ -265,15 +264,11 @@ public class DataFrameDataExtractor {
|
|||
.setTrackTotalHits(true);
|
||||
}
|
||||
|
||||
public Set<String> getCategoricalFields() {
|
||||
Set<String> categoricalFields = new HashSet<>();
|
||||
for (ExtractedField extractedField : context.extractedFields.getAllFields()) {
|
||||
String fieldName = extractedField.getName();
|
||||
if (Types.categorical().containsAll(extractedField.getTypes())) {
|
||||
categoricalFields.add(fieldName);
|
||||
}
|
||||
}
|
||||
return categoricalFields;
|
||||
public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
|
||||
return context.extractedFields.getAllFields().stream()
|
||||
.filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()).containsAll(extractedField.getTypes()))
|
||||
.map(ExtractedField::getName)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public static class DataSummary {
|
||||
|
|
|
@ -264,7 +264,14 @@ public class ExtractedFieldsDetector {
|
|||
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
|
||||
for (ExtractedField field : extractedFields.getAllFields()) {
|
||||
if (isBoolean(field.getTypes())) {
|
||||
adjusted.add(new BooleanAsInteger(field));
|
||||
if (config.getAnalysis().getAllowedCategoricalTypes(field.getAlias()).contains(BooleanFieldMapper.CONTENT_TYPE)) {
|
||||
// We convert boolean field to string if it is a categorical dependent variable
|
||||
adjusted.add(new BooleanMapper<>(field, Boolean.TRUE.toString(), Boolean.FALSE.toString()));
|
||||
} else {
|
||||
// We convert boolean fields to integers with values 0, 1 as this is the preferred
|
||||
// way to consume such features in the analytics process.
|
||||
adjusted.add(new BooleanMapper<>(field, 1, 0));
|
||||
}
|
||||
} else {
|
||||
adjusted.add(field);
|
||||
}
|
||||
|
@ -277,21 +284,24 @@ public class ExtractedFieldsDetector {
|
|||
}
|
||||
|
||||
/**
|
||||
* We convert boolean fields to integers with values 0, 1 as this is the preferred
|
||||
* way to consume such features in the analytics process.
|
||||
* {@link BooleanMapper} makes boolean field behave as a field of different type.
|
||||
*/
|
||||
private static class BooleanAsInteger extends ExtractedField {
|
||||
private static final class BooleanMapper<T> extends ExtractedField {
|
||||
|
||||
protected BooleanAsInteger(ExtractedField field) {
|
||||
private final T trueValue;
|
||||
private final T falseValue;
|
||||
|
||||
BooleanMapper(ExtractedField field, T trueValue, T falseValue) {
|
||||
super(field.getAlias(), field.getName(), Collections.singleton(BooleanFieldMapper.CONTENT_TYPE), ExtractionMethod.DOC_VALUE);
|
||||
this.trueValue = trueValue;
|
||||
this.falseValue = falseValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object[] value(SearchHit hit) {
|
||||
DocumentField keyValue = hit.field(name);
|
||||
if (keyValue != null) {
|
||||
List<Object> values = keyValue.getValues().stream().map(v -> Boolean.TRUE.equals(v) ? 1 : 0).collect(Collectors.toList());
|
||||
return values.toArray(new Object[0]);
|
||||
return keyValue.getValues().stream().map(v -> Boolean.TRUE.equals(v) ? trueValue : falseValue).toArray();
|
||||
}
|
||||
return new Object[0];
|
||||
}
|
||||
|
|
|
@ -362,7 +362,7 @@ public class AnalyticsProcessManager {
|
|||
|
||||
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) {
|
||||
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
|
||||
Set<String> categoricalFields = dataExtractor.getCategoricalFields();
|
||||
Set<String> categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis());
|
||||
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(config.getId(), dataSummary.rows, dataSummary.cols,
|
||||
config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis());
|
||||
return processConfig;
|
||||
|
|
|
@ -57,7 +57,7 @@ public class MemoryUsageEstimationProcessManager {
|
|||
DataFrameDataExtractorFactory dataExtractorFactory) {
|
||||
DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false);
|
||||
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
|
||||
Set<String> categoricalFields = dataExtractor.getCategoricalFields();
|
||||
Set<String> categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis());
|
||||
if (dataSummary.rows == 0) {
|
||||
return new MemoryUsageEstimationResult(ByteSizeValue.ZERO, ByteSizeValue.ZERO);
|
||||
}
|
||||
|
|
|
@ -24,6 +24,9 @@ import org.elasticsearch.search.SearchHit;
|
|||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
|
||||
|
@ -41,7 +44,9 @@ import java.util.Optional;
|
|||
import java.util.Queue;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
@ -384,6 +389,36 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
assertThat(dataExtractor.hasNext(), is(false));
|
||||
}
|
||||
|
||||
public void testGetCategoricalFields() {
|
||||
extractedFields = new ExtractedFields(Arrays.asList(
|
||||
ExtractedField.newField("field_boolean", Collections.singleton("boolean"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_float", Collections.singleton("float"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_double", Collections.singleton("double"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_byte", Collections.singleton("byte"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_short", Collections.singleton("short"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_integer", Collections.singleton("integer"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_long", Collections.singleton("long"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_keyword", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE),
|
||||
ExtractedField.newField("field_text", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE)));
|
||||
TestExtractor dataExtractor = createExtractor(true, true);
|
||||
|
||||
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
|
||||
|
||||
assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")), containsInAnyOrder("field_keyword", "field_text"));
|
||||
assertThat(dataExtractor.getCategoricalFields(new Regression("field_long")), containsInAnyOrder("field_keyword", "field_text"));
|
||||
assertThat(dataExtractor.getCategoricalFields(new Regression("field_boolean")), containsInAnyOrder("field_keyword", "field_text"));
|
||||
|
||||
assertThat(
|
||||
dataExtractor.getCategoricalFields(new Classification("field_keyword")),
|
||||
containsInAnyOrder("field_keyword", "field_text"));
|
||||
assertThat(
|
||||
dataExtractor.getCategoricalFields(new Classification("field_long")),
|
||||
containsInAnyOrder("field_keyword", "field_text", "field_long"));
|
||||
assertThat(
|
||||
dataExtractor.getCategoricalFields(new Classification("field_boolean")),
|
||||
containsInAnyOrder("field_keyword", "field_text", "field_boolean"));
|
||||
}
|
||||
|
||||
private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) {
|
||||
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(
|
||||
JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues);
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
|
||||
|
@ -213,6 +214,22 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
"expected types are [byte, double, float, half_float, integer, long, scaled_float, short]"));
|
||||
}
|
||||
|
||||
public void testDetect_GivenClassificationAndRequiredFieldHasInvalidType() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("some_float", "float")
|
||||
.addAggregatableField("some_long", "long")
|
||||
.addAggregatableField("some_keyword", "keyword")
|
||||
.addAggregatableField("foo", "keyword")
|
||||
.build();
|
||||
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
SOURCE_INDEX, buildClassificationConfig("some_float"), RESULTS_FIELD, false, 100, fieldCapabilities);
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("invalid types [float] for required field [some_float]; " +
|
||||
"expected types are [boolean, byte, integer, ip, keyword, long, short, text]"));
|
||||
}
|
||||
|
||||
public void testDetect_GivenIgnoredField() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("_id", "float").build();
|
||||
|
@ -467,7 +484,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
contains(equalTo(ExtractedField.ExtractionMethod.SOURCE)));
|
||||
}
|
||||
|
||||
public void testDetect_GivenBooleanField() {
|
||||
public void testDetect_GivenBooleanField_BooleanMappedAsInteger() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("some_boolean", "boolean")
|
||||
.build();
|
||||
|
@ -483,19 +500,38 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE));
|
||||
|
||||
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
|
||||
Object[] values = booleanField.value(hit);
|
||||
assertThat(values.length, equalTo(1));
|
||||
assertThat(values[0], equalTo(1));
|
||||
assertThat(booleanField.value(hit), arrayContaining(1));
|
||||
|
||||
hit = new SearchHitBuilder(42).addField("some_boolean", false).build();
|
||||
values = booleanField.value(hit);
|
||||
assertThat(values.length, equalTo(1));
|
||||
assertThat(values[0], equalTo(0));
|
||||
assertThat(booleanField.value(hit), arrayContaining(0));
|
||||
|
||||
hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build();
|
||||
values = booleanField.value(hit);
|
||||
assertThat(values.length, equalTo(3));
|
||||
assertThat(values, arrayContaining(0, 1, 0));
|
||||
assertThat(booleanField.value(hit), arrayContaining(0, 1, 0));
|
||||
}
|
||||
|
||||
public void testDetect_GivenBooleanField_BooleanMappedAsString() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("some_boolean", "boolean")
|
||||
.build();
|
||||
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
SOURCE_INDEX, buildClassificationConfig("some_boolean"), RESULTS_FIELD, false, 100, fieldCapabilities);
|
||||
ExtractedFields extractedFields = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = extractedFields.getAllFields();
|
||||
assertThat(allFields.size(), equalTo(1));
|
||||
ExtractedField booleanField = allFields.get(0);
|
||||
assertThat(booleanField.getTypes(), contains("boolean"));
|
||||
assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE));
|
||||
|
||||
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
|
||||
assertThat(booleanField.value(hit), arrayContaining("true"));
|
||||
|
||||
hit = new SearchHitBuilder(42).addField("some_boolean", false).build();
|
||||
assertThat(booleanField.value(hit), arrayContaining("false"));
|
||||
|
||||
hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build();
|
||||
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
|
||||
}
|
||||
|
||||
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
|
||||
|
@ -526,6 +562,15 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
.build();
|
||||
}
|
||||
|
||||
private static DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) {
|
||||
return new DataFrameAnalyticsConfig.Builder()
|
||||
.setId("foo")
|
||||
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
|
||||
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
|
||||
.setAnalysis(new Classification(dependentVariable))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static class MockFieldCapsResponseBuilder {
|
||||
|
||||
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
|
||||
|
|
Loading…
Reference in New Issue