From 5921ae53d807c387ebb1ff60dd40c589c50f2bfe Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 30 Aug 2019 09:57:43 +0300 Subject: [PATCH] [7.x][ML] Regression dependent variable must be numeric (#46072) (#46136) * [ML] Regression dependent variable must be numeric This adds a validation that the dependent variable of a regression analysis must be numeric. * Address review comments and fix some problems In addition to addressing the review comments, this commit fixes a few issues I found during testing. In particular: - if there were mappings for required fields but they were not included we were not reporting the error - if explicitly included fields had unsupported types we were not reporting the error Unfortunately, I couldn't get those fixed without refactoring the code in `ExtractedFieldsDetector`. --- .../dataframe/analyses/DataFrameAnalysis.java | 6 +- .../dataframe/analyses/OutlierDetection.java | 6 +- .../ml/dataframe/analyses/Regression.java | 6 +- .../ml/dataframe/analyses/RequiredField.java | 36 ++++ .../core/ml/dataframe/analyses/Types.java | 43 +++++ .../extractor/DataFrameDataExtractor.java | 3 +- .../extractor/ExtractedFieldsDetector.java | 156 +++++++++++------- .../ExtractedFieldsDetectorTests.java | 150 +++++++++++++++-- 8 files changed, 326 insertions(+), 80 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RequiredField.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index 0ea15b6f803..bc0e623cdeb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; +import java.util.List; import java.util.Map; -import java.util.Set; public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { @@ -24,9 +24,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { boolean supportsCategoricalFields(); /** - * @return The set of fields that analyzed documents must have for the analysis to operate + * @return The names and types of the fields that analyzed documents must have for the analysis to operate */ - Set getRequiredFields(); + List getRequiredFields(); /** * @return {@code true} if this analysis supports data frame rows with missing values diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 32a47890572..35e3d234a7c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -18,10 +18,10 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.Set; public class OutlierDetection implements DataFrameAnalysis { @@ -160,8 +160,8 @@ public class OutlierDetection implements DataFrameAnalysis { } @Override - public Set getRequiredFields() { - return Collections.emptySet(); + public List getRequiredFields() { + return Collections.emptyList(); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 966a67c22c0..04a5801ffa2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -17,9 +17,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; public class Regression implements DataFrameAnalysis { @@ -201,8 +201,8 @@ public class Regression implements DataFrameAnalysis { } @Override - public Set getRequiredFields() { - return Collections.singleton(dependentVariable); + public List getRequiredFields() { + return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical())); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RequiredField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RequiredField.java new file mode 100644 index 00000000000..bca96b1a1b1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RequiredField.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + +public class RequiredField { + + private final String name; + + /** + * The required field must have one of those types. + * We use a sorted set to ensure types are reported alphabetically in error messages. + */ + private final SortedSet types; + + public RequiredField(String name, Set types) { + this.name = Objects.requireNonNull(name); + this.types = Collections.unmodifiableSortedSet(new TreeSet<>(types)); + } + + public String getName() { + return name; + } + + public SortedSet getTypes() { + return types; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java new file mode 100644 index 00000000000..ba7cac81d7f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.index.mapper.NumberFieldMapper; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Helper class that defines groups of types + */ +public final class Types { + + private Types() {} + + private static final Set CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip"))); + + private static final Set NUMERICAL_TYPES; + + static { + Set numericalTypes = Stream.of(NumberFieldMapper.NumberType.values()) + .map(NumberFieldMapper.NumberType::typeName) + .collect(Collectors.toSet()); + numericalTypes.add("scaled_float"); + NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes); + } + + public static Set categorical() { + return CATEGORICAL_TYPES; + } + + public static Set numerical() { + return NUMERICAL_TYPES; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 738c5814361..657608d08bb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex; @@ -268,7 +269,7 @@ public class DataFrameDataExtractor { Set categoricalFields = new HashSet<>(); for (ExtractedField extractedField : context.extractedFields.getAllFields()) { String fieldName = extractedField.getName(); - if (ExtractedFieldsDetector.CATEGORICAL_TYPES.containsAll(extractedField.getTypes())) { + if (Types.categorical().containsAll(extractedField.getTypes())) { categoricalFields.add(fieldName); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index df0b8011868..dc173f4d8ff 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -15,11 +15,12 @@ import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.BooleanFieldMapper; -import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NameResolver; @@ -35,10 +36,10 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; -import java.util.stream.Stream; public class ExtractedFieldsDetector { @@ -50,18 +51,6 @@ public class ExtractedFieldsDetector { private static final List IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", "_source", "_type", "_uid", "_version", "_feature", "_ignored", DataFrameAnalyticsIndex.ID_COPY); - public static final Set CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip"))); - - private static final Set NUMERICAL_TYPES; - - static { - Set numericalTypes = Stream.of(NumberFieldMapper.NumberType.values()) - .map(NumberFieldMapper.NumberType::typeName) - .collect(Collectors.toSet()); - numericalTypes.add("scaled_float"); - NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes); - } - private final String[] index; private final DataFrameAnalyticsConfig config; private final String resultsField; @@ -80,12 +69,7 @@ public class ExtractedFieldsDetector { } public ExtractedFields detect() { - Set fields = new HashSet<>(fieldCapabilitiesResponse.get().keySet()); - fields.removeAll(IGNORE_FIELDS); - removeFieldsUnderResultsField(fields); - includeAndExcludeFields(fields); - removeFieldsWithIncompatibleTypes(fields); - checkRequiredFieldsArePresent(fields); + Set fields = getIncludedFields(); if (fields.isEmpty()) { throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index {}. Supported types are {}.", @@ -93,20 +77,24 @@ public class ExtractedFieldsDetector { getSupportedTypes()); } - List sortedFields = new ArrayList<>(fields); - // We sort the fields to ensure the checksum for each document is deterministic - Collections.sort(sortedFields); - ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse); - if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { - extractedFields = fetchFromSourceIfSupported(extractedFields); - if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { - throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " + - "please adjust the index level setting [{}]", extractedFields.getDocValueFields().size(), docValueFieldsLimit, - IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey()); - } + checkNoIgnoredFields(fields); + checkFieldsHaveCompatibleTypes(fields); + checkRequiredFields(fields); + return detectExtractedFields(fields); + } + + private Set getIncludedFields() { + Set fields = new HashSet<>(fieldCapabilitiesResponse.get().keySet()); + removeFieldsUnderResultsField(fields); + FetchSourceContext analyzedFields = config.getAnalyzedFields(); + + // If the user has not explicitly included fields we'll include all compatible fields + if (analyzedFields == null || analyzedFields.includes().length == 0) { + fields.removeAll(IGNORE_FIELDS); + removeFieldsWithIncompatibleTypes(fields); } - extractedFields = fetchBooleanFieldsAsIntegers(extractedFields); - return extractedFields; + includeAndExcludeFields(fields); + return fields; } private void removeFieldsUnderResultsField(Set fields) { @@ -139,31 +127,38 @@ public class ExtractedFieldsDetector { Iterator fieldsIterator = fields.iterator(); while (fieldsIterator.hasNext()) { String field = fieldsIterator.next(); - Map fieldCaps = fieldCapabilitiesResponse.getField(field); - if (fieldCaps == null) { - LOGGER.debug("[{}] Removing field [{}] because it is missing from mappings", config.getId(), field); + if (hasCompatibleType(field) == false) { fieldsIterator.remove(); - } else { - Set fieldTypes = fieldCaps.keySet(); - if (NUMERICAL_TYPES.containsAll(fieldTypes)) { - LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field); - } else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) { - LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field); - } else if (isBoolean(fieldTypes)) { - LOGGER.debug("[{}] field [{}] is compatible as it is boolean", config.getId(), field); - } else { - LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}; supported {}", - config.getId(), field, fieldTypes, getSupportedTypes()); - fieldsIterator.remove(); - } } } } + private boolean hasCompatibleType(String field) { + Map fieldCaps = fieldCapabilitiesResponse.getField(field); + if (fieldCaps == null) { + LOGGER.debug("[{}] incompatible field [{}] because it is missing from mappings", config.getId(), field); + return false; + } + Set fieldTypes = fieldCaps.keySet(); + if (Types.numerical().containsAll(fieldTypes)) { + LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field); + return true; + } else if (config.getAnalysis().supportsCategoricalFields() && Types.categorical().containsAll(fieldTypes)) { + LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field); + return true; + } else if (isBoolean(fieldTypes)) { + LOGGER.debug("[{}] field [{}] is compatible as it is boolean", config.getId(), field); + return true; + } else { + LOGGER.debug("[{}] incompatible field [{}]; types {}; supported {}", config.getId(), field, fieldTypes, getSupportedTypes()); + return false; + } + } + private Set getSupportedTypes() { - Set supportedTypes = new TreeSet<>(NUMERICAL_TYPES); + Set supportedTypes = new TreeSet<>(Types.numerical()); if (config.getAnalysis().supportsCategoricalFields()) { - supportedTypes.addAll(CATEGORICAL_TYPES); + supportedTypes.addAll(Types.categorical()); } supportedTypes.add(BooleanFieldMapper.CONTENT_TYPE); return supportedTypes; @@ -202,16 +197,61 @@ public class ExtractedFieldsDetector { } } - private void checkRequiredFieldsArePresent(Set fields) { - List missingFields = config.getAnalysis().getRequiredFields() - .stream() - .filter(f -> fields.contains(f) == false) - .collect(Collectors.toList()); - if (missingFields.isEmpty() == false) { - throw ExceptionsHelper.badRequestException("required fields {} are missing", missingFields); + private void checkNoIgnoredFields(Set fields) { + Optional ignoreField = IGNORE_FIELDS.stream().filter(fields::contains).findFirst(); + if (ignoreField.isPresent()) { + throw ExceptionsHelper.badRequestException("field [{}] cannot be analyzed", ignoreField.get()); } } + private void checkFieldsHaveCompatibleTypes(Set fields) { + for (String field : fields) { + Map fieldCaps = fieldCapabilitiesResponse.getField(field); + if (fieldCaps == null) { + throw ExceptionsHelper.badRequestException("no mappings could be found for field [{}]", field); + } + + if (hasCompatibleType(field) == false) { + throw ExceptionsHelper.badRequestException("field [{}] has unsupported type {}. Supported types are {}.", field, + fieldCaps.keySet(), getSupportedTypes()); + } + } + } + + private void checkRequiredFields(Set fields) { + List requiredFields = config.getAnalysis().getRequiredFields(); + for (RequiredField requiredField : requiredFields) { + Map fieldCaps = fieldCapabilitiesResponse.getField(requiredField.getName()); + if (fields.contains(requiredField.getName()) == false || fieldCaps == null || fieldCaps.isEmpty()) { + List requiredFieldNames = requiredFields.stream().map(RequiredField::getName).collect(Collectors.toList()); + throw ExceptionsHelper.badRequestException("required field [{}] is missing; analysis requires fields {}", + requiredField.getName(), requiredFieldNames); + } + Set fieldTypes = fieldCaps.keySet(); + if (requiredField.getTypes().containsAll(fieldTypes) == false) { + throw ExceptionsHelper.badRequestException("invalid types {} for required field [{}]; expected types are {}", + fieldTypes, requiredField.getName(), requiredField.getTypes()); + } + } + } + + private ExtractedFields detectExtractedFields(Set fields) { + List sortedFields = new ArrayList<>(fields); + // We sort the fields to ensure the checksum for each document is deterministic + Collections.sort(sortedFields); + ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse); + if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { + extractedFields = fetchFromSourceIfSupported(extractedFields); + if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { + throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " + + "please adjust the index level setting [{}]", extractedFields.getDocValueFields().size(), docValueFieldsLimit, + IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey()); + } + } + extractedFields = fetchBooleanFieldsAsIntegers(extractedFields); + return extractedFields; + } + private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) { List adjusted = new ArrayList<>(extractedFields.getAllFields().size()); for (ExtractedField field : extractedFields.getDocValueFields()) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index c02f35be520..db381373709 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -120,7 +120,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { .addAggregatableField("some_long", "long") .addAggregatableField("some_keyword", "keyword") .addAggregatableField("some_boolean", "boolean") - .addAggregatableField("foo", "keyword") + .addAggregatableField("foo", "double") .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( @@ -146,7 +146,71 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildRegressionConfig("foo"), RESULTS_FIELD, false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); - assertThat(e.getMessage(), equalTo("required fields [foo] are missing")); + assertThat(e.getMessage(), equalTo("required field [foo] is missing; analysis requires fields [foo]")); + } + + public void testDetect_GivenRegressionAndRequiredFieldExcluded() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .addAggregatableField("foo", "float") + .build(); + FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[0], new String[] {"foo"}); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("foo", analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("required field [foo] is missing; analysis requires fields [foo]")); + } + + public void testDetect_GivenRegressionAndRequiredFieldNotIncluded() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .addAggregatableField("foo", "float") + .build(); + FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[] {"some_float", "some_keyword"}, new String[0]); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("foo", analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("required field [foo] is missing; analysis requires fields [foo]")); + } + + public void testDetect_GivenFieldIsBothIncludedAndExcluded() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("foo", "float") + .addAggregatableField("bar", "float") + .build(); + FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[] {"foo", "bar"}, new String[] {"foo"}); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildOutlierDetectionConfig(analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + List allFields = extractedFields.getAllFields(); + assertThat(allFields.size(), equalTo(1)); + assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()), contains("bar")); + } + + public void testDetect_GivenRegressionAndRequiredFieldHasInvalidType() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .addAggregatableField("foo", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("foo"), RESULTS_FIELD, false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("invalid types [keyword] for required field [foo]; " + + "expected types are [byte, double, float, half_float, integer, long, scaled_float, short]")); } public void testDetect_GivenIgnoredField() { @@ -161,6 +225,18 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { "Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short].")); } + public void testDetect_GivenIncludedIgnoredField() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("_id", "float").build(); + FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[]{"_id"}, new String[0]); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildOutlierDetectionConfig(analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("field [_id] cannot be analyzed")); + } + public void testDetect_ShouldSortFieldsAlphabetically() { int fieldCount = randomIntBetween(10, 20); List fields = new ArrayList<>(); @@ -185,7 +261,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(extractedFieldNames, equalTo(sortedFields)); } - public void testDetectedExtractedFields_GivenIncludeWithMissingField() { + public void testDetect_GivenIncludeWithMissingField() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("my_field1", "float") .addAggregatableField("my_field2", "float") @@ -200,7 +276,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected")); } - public void testDetectedExtractedFields_GivenExcludeAllValidFields() { + public void testDetect_GivenExcludeAllValidFields() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("my_field1", "float") .addAggregatableField("my_field2", "float") @@ -215,12 +291,11 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { "Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short].")); } - public void testDetectedExtractedFields_GivenInclusionsAndExclusions() { + public void testDetect_GivenInclusionsAndExclusions() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("my_field1_nope", "float") .addAggregatableField("my_field1", "float") .addAggregatableField("your_field2", "float") - .addAggregatableField("your_keyword", "keyword") .build(); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); @@ -234,7 +309,25 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); } - public void testDetectedExtractedFields_GivenIndexContainsResultsField() { + public void testDetect_GivenIncludedFieldHasUnsupportedType() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("my_field1_nope", "float") + .addAggregatableField("my_field1", "float") + .addAggregatableField("your_field2", "float") + .addAggregatableField("your_keyword", "keyword") + .build(); + + FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), RESULTS_FIELD, false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("field [your_keyword] has unsupported type [keyword]. " + + "Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short].")); + } + + public void testDetect_GivenIndexContainsResultsField() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField(RESULTS_FIELD, "float") .addAggregatableField("my_field1", "float") @@ -250,7 +343,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { "please set a different results_field")); } - public void testDetectedExtractedFields_GivenIndexContainsResultsFieldAndTaskIsRestarting() { + public void testDetect_GivenIndexContainsResultsFieldAndTaskIsRestarting() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField(RESULTS_FIELD + ".outlier_score", "float") .addAggregatableField("my_field1", "float") @@ -267,7 +360,40 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); } - public void testDetectedExtractedFields_NullResultsField() { + public void testDetect_GivenIncludedResultsField() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField(RESULTS_FIELD, "float") + .addAggregatableField("my_field1", "float") + .addAggregatableField("your_field2", "float") + .addAggregatableField("your_keyword", "keyword") + .build(); + FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[]{RESULTS_FIELD}, new String[0]); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildOutlierDetectionConfig(analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " + + "please set a different results_field")); + } + + public void testDetect_GivenIncludedResultsFieldAndTaskIsRestarting() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField(RESULTS_FIELD + ".outlier_score", "float") + .addAggregatableField("my_field1", "float") + .addAggregatableField("your_field2", "float") + .addAggregatableField("your_keyword", "keyword") + .build(); + FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[]{RESULTS_FIELD}, new String[0]); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildOutlierDetectionConfig(analyzedFields), RESULTS_FIELD, true, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("No field [ml] could be detected")); + } + + public void testDetect_NullResultsField() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField(RESULTS_FIELD, "float") .addAggregatableField("my_field1", "float") @@ -284,7 +410,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(extractedFieldNames, equalTo(Arrays.asList(RESULTS_FIELD, "my_field1", "your_field2"))); } - public void testDetectedExtractedFields_GivenLessFieldsThanDocValuesLimit() { + public void testDetect_GivenLessFieldsThanDocValuesLimit() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("field_1", "float") .addAggregatableField("field_2", "float") @@ -303,7 +429,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); } - public void testDetectedExtractedFields_GivenEqualFieldsToDocValuesLimit() { + public void testDetect_GivenEqualFieldsToDocValuesLimit() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("field_1", "float") .addAggregatableField("field_2", "float") @@ -322,7 +448,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); } - public void testDetectedExtractedFields_GivenMoreFieldsThanDocValuesLimit() { + public void testDetect_GivenMoreFieldsThanDocValuesLimit() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("field_1", "float") .addAggregatableField("field_2", "float")