diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index ddb6a921200..68a312fbcf0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -46,9 +46,16 @@ public class Classification implements DataFrameAnalysis { private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1"; + private static final String NUM_CLASSES = "num_classes"; + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + /** + * The max number of classes classification supports + */ + private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30; + private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), @@ -220,7 +227,7 @@ public class Classification implements DataFrameAnalysis { } @Override - public Map getParams(Map> extractedFields) { + public Map getParams(FieldInfo fieldInfo) { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); @@ -229,10 +236,11 @@ public class Classification implements DataFrameAnalysis { if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } - String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable)); + String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable)); if (predictionFieldType != null) { params.put(PREDICTION_FIELD_TYPE, predictionFieldType); } + params.put(NUM_CLASSES, 
fieldInfo.getCardinality(dependentVariable)); return params; } @@ -274,7 +282,7 @@ public class Classification implements DataFrameAnalysis { @Override public List getFieldCardinalityConstraints() { // The dependent variable must have at least 2 and at most MAX_DEPENDENT_VARIABLE_CARDINALITY distinct values (classes). - return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2)); + return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, MAX_DEPENDENT_VARIABLE_CARDINALITY)); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index 664b38e4fc0..941224dc30a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -16,9 +17,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { /** * @return The analysis parameters as a map - * @param extractedFields map of (name, types) for all the extracted fields + * @param fieldInfo Information about the fields like types and cardinalities */ - Map getParams(Map> extractedFields); + Map getParams(FieldInfo fieldInfo); /** * @return {@code true} if this analysis supports fields with categorical values (i.e.
text, keyword, ip) @@ -64,4 +65,27 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { * Returns the document id for the analysis state */ String getStateDocId(String jobId); + + /** + * Summarizes information about the fields that is necessary for analysis to generate + * the parameters needed for the process configuration. + */ + interface FieldInfo { + + /** + * Returns the types for the given field or {@code null} if the field is unknown + * @param field the field whose types to return + * @return the types for the given field or {@code null} if the field is unknown + */ + @Nullable + Set getTypes(String field); + + /** + * Returns the cardinality of the given field or {@code null} if there is no cardinality for that field + * @param field the field whose cardinality to get + * @return the cardinality of the given field or {@code null} if there is no cardinality for that field + */ + @Nullable + Long getCardinality(String field); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 654d5ba4d1a..2c83afa8780 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -192,7 +192,7 @@ public class OutlierDetection implements DataFrameAnalysis { } @Override - public Map getParams(Map> extractedFields) { + public Map getParams(FieldInfo fieldInfo) { Map params = new HashMap<>(); if (nNeighbors != null) { params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 
86f8039090c..d8c490ddcb8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -155,7 +155,7 @@ public class Regression implements DataFrameAnalysis { } @Override - public Map getParams(Map> extractedFields) { + public Map getParams(FieldInfo fieldInfo) { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index ef3f4f0082c..5d3f0140d1c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -188,34 +188,45 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase> extractedFields = new HashMap<>(3); - extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE)); - extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName())); - extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE)); + Map> fieldTypes = new HashMap<>(3); + fieldTypes.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE)); + fieldTypes.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName())); + fieldTypes.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE)); + + Map fieldCardinalities = new HashMap<>(); + fieldCardinalities.put("foo", 10L); + fieldCardinalities.put("bar", 20L); + fieldCardinalities.put("baz", 30L); + + DataFrameAnalysis.FieldInfo fieldInfo = new 
TestFieldInfo(fieldTypes, fieldCardinalities); + assertThat( - new Classification("foo").getParams(extractedFields), + new Classification("foo").getParams(fieldInfo), Matchers.>allOf( hasEntry("dependent_variable", "foo"), hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("num_top_classes", 2), hasEntry("prediction_field_name", "foo_prediction"), - hasEntry("prediction_field_type", "bool"))); + hasEntry("prediction_field_type", "bool"), + hasEntry("num_classes", 10L))); assertThat( - new Classification("bar").getParams(extractedFields), + new Classification("bar").getParams(fieldInfo), Matchers.>allOf( hasEntry("dependent_variable", "bar"), hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("num_top_classes", 2), hasEntry("prediction_field_name", "bar_prediction"), - hasEntry("prediction_field_type", "int"))); + hasEntry("prediction_field_type", "int"), + hasEntry("num_classes", 20L))); assertThat( - new Classification("baz").getParams(extractedFields), + new Classification("baz").getParams(fieldInfo), Matchers.>allOf( hasEntry("dependent_variable", "baz"), hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("num_top_classes", 2), hasEntry("prediction_field_name", "baz_prediction"), - hasEntry("prediction_field_type", "string"))); + hasEntry("prediction_field_type", "string"), + hasEntry("num_classes", 30L))); } public void testRequiredFieldsIsNonEmpty() { @@ -229,7 +240,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase> fieldTypes; + private final Map fieldCardinalities; + + private TestFieldInfo(Map> fieldTypes, Map fieldCardinalities) { + this.fieldTypes = fieldTypes; + this.fieldCardinalities = fieldCardinalities; + } + + @Override + public Set getTypes(String field) { + return fieldTypes.get(field); + } + + @Override + public Long 
getCardinality(String field) { + return fieldCardinalities.get(field); + } + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index e403f4cf8b4..9a3109b366c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -323,9 +323,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } + @AwaitsFix(bugUrl = "Muted until ml-cpp supports multiple classes") public void testDependentVariableCardinalityTooHighError() throws Exception { initialize("cardinality_too_high"); indexData(sourceIndex, 6, 5, KEYWORD_FIELD); + // Index one more document with a class different than the two already used. 
client().execute( IndexAction.INSTANCE, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java index 8202c0ef3d2..eb13f2395ef 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; @@ -27,7 +28,7 @@ public class TimeBasedExtractedFields extends ExtractedFields { private final ExtractedField timeField; public TimeBasedExtractedFields(ExtractedField timeField, List allFields) { - super(allFields); + super(allFields, Collections.emptyMap()); if (!allFields.contains(timeField)) { throw new IllegalArgumentException("timeField should also be contained in allFields"); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index be5ea6b83ea..463974bcce2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -58,15 +58,15 @@ public class ExtractedFieldsDetector { private final DataFrameAnalyticsConfig config; private final int docValueFieldsLimit; private final FieldCapabilitiesResponse fieldCapabilitiesResponse; - private final Map fieldCardinalities; + private final Map cardinalitiesForFieldsWithConstraints; 
ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit, - FieldCapabilitiesResponse fieldCapabilitiesResponse, Map fieldCardinalities) { + FieldCapabilitiesResponse fieldCapabilitiesResponse, Map cardinalitiesForFieldsWithConstraints) { this.index = Objects.requireNonNull(index); this.config = Objects.requireNonNull(config); this.docValueFieldsLimit = docValueFieldsLimit; this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse); - this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities); + this.cardinalitiesForFieldsWithConstraints = Objects.requireNonNull(cardinalitiesForFieldsWithConstraints); } public Tuple> detect() { @@ -286,12 +286,13 @@ public class ExtractedFieldsDetector { private void checkFieldsWithCardinalityLimit() { for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) { - constraint.check(fieldCardinalities.get(constraint.getField())); + constraint.check(cardinalitiesForFieldsWithConstraints.get(constraint.getField())); } } private ExtractedFields detectExtractedFields(Set fields, Set fieldSelection) { - ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse); + ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse, + cardinalitiesForFieldsWithConstraints); boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit; extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection); if (preferSource) { @@ -321,7 +322,7 @@ public class ExtractedFieldsDetector { chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection)); } } - return new ExtractedFields(new ArrayList<>(nameOrParentToField.values())); + return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints); } private 
ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set requiredFields, ExtractedField parent, @@ -372,7 +373,7 @@ public class ExtractedFieldsDetector { for (ExtractedField field : extractedFields.getAllFields()) { adjusted.add(field.supportsFromSource() ? field.newFromSource() : field); } - return new ExtractedFields(adjusted); + return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints); } private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) { @@ -389,7 +390,7 @@ public class ExtractedFieldsDetector { adjusted.add(field); } } - return new ExtractedFields(adjusted); + return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints); } private void addIncludedFields(ExtractedFields extractedFields, Set fieldSelection) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 714f6309180..0daec5365ff 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -14,10 +14,9 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.io.IOException; import java.util.Objects; +import java.util.Optional; import java.util.Set; -import static java.util.stream.Collectors.toMap; - public class AnalyticsProcessConfig implements ToXContentObject { private static final String JOB_ID = "job_id"; @@ -93,12 +92,31 @@ public class AnalyticsProcessConfig implements ToXContentObject { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("name", analysis.getWriteableName()); - builder.field( - "parameters", - analysis.getParams( - 
extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes)))); + builder.field("parameters", analysis.getParams(new AnalysisFieldInfo(extractedFields))); builder.endObject(); return builder; } } + + private static class AnalysisFieldInfo implements DataFrameAnalysis.FieldInfo { + + private final ExtractedFields extractedFields; + + AnalysisFieldInfo(ExtractedFields extractedFields) { + this.extractedFields = Objects.requireNonNull(extractedFields); + } + + @Override + public Set getTypes(String field) { + Optional extractedField = extractedFields.getAllFields().stream() + .filter(f -> f.getName().equals(field)) + .findAny(); + return extractedField.isPresent() ? extractedField.get().getTypes() : null; + } + + @Override + public Long getCardinality(String field) { + return extractedFields.getCardinalitiesForFieldsWithConstraints().get(field); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java index 3a36bb7ff76..ab314a5d218 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java @@ -28,12 +28,14 @@ public class ExtractedFields { private final List allFields; private final List docValueFields; private final String[] sourceFields; + private final Map cardinalitiesForFieldsWithConstraints; - public ExtractedFields(List allFields) { + public ExtractedFields(List allFields, Map cardinalitiesForFieldsWithConstraints) { this.allFields = Collections.unmodifiableList(allFields); this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields); this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField) .toArray(String[]::new); + this.cardinalitiesForFieldsWithConstraints 
= Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints); } public List getAllFields() { @@ -48,14 +50,20 @@ public class ExtractedFields { return docValueFields; } + public Map getCardinalitiesForFieldsWithConstraints() { + return cardinalitiesForFieldsWithConstraints; + } + private static List filterFields(ExtractedField.Method method, List fields) { return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList()); } public static ExtractedFields build(Collection allFields, Set scriptFields, - FieldCapabilitiesResponse fieldsCapabilities) { + FieldCapabilitiesResponse fieldsCapabilities, + Map cardinalitiesForFieldsWithConstraints) { ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities); - return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList())); + return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()), + cardinalitiesForFieldsWithConstraints); } public static TimeField newTimeField(String name, ExtractedField.Method method) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index 65e33da40f3..c482105de89 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -81,7 +81,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { query = QueryBuilders.matchAllQuery(); extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("keyword")), - new DocValueField("field_2", 
Collections.singleton("keyword")))); + new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap()); scrollSize = 1000; headers = Collections.emptyMap(); @@ -299,7 +299,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915 extractedFields = new ExtractedFields(Arrays.asList( (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")), - (ExtractedField) new SourceField("field_2", Collections.singleton("text")))); + (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap()); TestExtractor dataExtractor = createExtractor(false, false); @@ -404,7 +404,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")), (ExtractedField) new DocValueField("field_long", Collections.singleton("long")), (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")), - (ExtractedField) new SourceField("field_text", Collections.singleton("text")))); + (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap()); TestExtractor dataExtractor = createExtractor(true, true); assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index 49a302a498b..a7a5784c452 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -294,10 +294,10 @@ public class 
ExtractedFieldsDetectorTests extends ESTestCase { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(SOURCE_INDEX, - buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 3L)); + buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 31L)); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); - assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [2] distinct values but there were at least [3]")); + assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [30] distinct values but there were at least [31]")); } public void testDetect_GivenIgnoredField() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java new file mode 100644 index 00000000000..a4db8de032a --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.ml.extractor.DocValueField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasKey; + +public class AnalyticsProcessConfigTests extends ESTestCase { + + private String jobId; + private long rows; + private int cols; + private ByteSizeValue memoryLimit; + private int threads; + private String resultsField; + private Set categoricalFields; + + @Before + public void setUpConfigParams() { + jobId = randomAlphaOfLength(10); + rows = randomNonNegativeLong(); + cols = randomIntBetween(1, 42000); + memoryLimit = new ByteSizeValue(randomNonNegativeLong(), ByteSizeUnit.BYTES); + threads = randomIntBetween(1, 8); + resultsField = randomAlphaOfLength(10); + + int categoricalFieldsSize = randomIntBetween(0, 5); + categoricalFields = new HashSet<>(); + for (int i = 0; i < categoricalFieldsSize; i++) { + 
categoricalFields.add(randomAlphaOfLength(10)); + } + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenOutlierDetection() throws IOException { + ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap()); + DataFrameAnalysis analysis = new OutlierDetection.Builder().build(); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "outlier_detection")); + assertThat(analysisAsMap, hasKey("parameters")); + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenRegression() throws IOException { + ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float")), + new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap()); + DataFrameAnalysis analysis = new Regression("test_dep_var"); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "regression")); + assertThat(analysisAsMap, hasKey("parameters")); + Map paramsAsMap = (Map) analysisAsMap.get("parameters"); + assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var")); + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenClassificationAndDepVarIsKeyword() throws IOException { + ExtractedFields extractedFields = new 
ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float")), + new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L)); + DataFrameAnalysis analysis = new Classification("test_dep_var"); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "classification")); + assertThat(analysisAsMap, hasKey("parameters")); + Map paramsAsMap = (Map) analysisAsMap.get("parameters"); + assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var")); + assertThat(paramsAsMap, hasEntry("prediction_field_type", "string")); + assertThat(paramsAsMap, hasEntry("num_classes", 5)); + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenClassificationAndDepVarIsInteger() throws IOException { + ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float")), + new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L)); + DataFrameAnalysis analysis = new Classification("test_dep_var"); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "classification")); + assertThat(analysisAsMap, hasKey("parameters")); + Map paramsAsMap = (Map) analysisAsMap.get("parameters"); + assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var")); 
+ assertThat(paramsAsMap, hasEntry("prediction_field_type", "int")); + assertThat(paramsAsMap, hasEntry("num_classes", 8)); + } + + private AnalyticsProcessConfig createProcessConfig(DataFrameAnalysis analysis, ExtractedFields extractedFields) { + return new AnalyticsProcessConfig(jobId, rows, cols, memoryLimit, threads, resultsField, categoricalFields, analysis, + extractedFields); + } + + private static Map toMap(AnalyticsProcessConfig config) throws IOException { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + config.toXContent(builder, ToXContent.EMPTY_PARAMS); + return XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(builder), false); + } + } + + @SuppressWarnings("unchecked") + private void assertRandomizedFields(Map configAsMap) { + assertThat(configAsMap, hasEntry("job_id", jobId)); + assertThat(configAsMap, hasEntry("rows", rows)); + assertThat(configAsMap, hasEntry("cols", cols)); + assertThat(configAsMap, hasEntry("memory_limit", memoryLimit.getBytes())); + assertThat(configAsMap, hasEntry("threads", threads)); + assertThat(configAsMap, hasEntry("results_field", resultsField)); + assertThat(configAsMap, hasKey("categorical_fields")); + assertThat((List) configAsMap.get("categorical_fields"), containsInAnyOrder(categoricalFields.toArray())); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java index 5ac983e7d50..a51eafd1d8b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java @@ -32,7 +32,7 @@ public class ExtractedFieldsTests extends ESTestCase { ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text")); ExtractedField sourceField2 = new SourceField("src2", 
Collections.singleton("text")); ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( - docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); + docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap()); assertThat(extractedFields.getAllFields().size(), equalTo(6)); assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new), @@ -54,7 +54,7 @@ public class ExtractedFieldsTests extends ESTestCase { when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps); ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"), - new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse); + new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap()); assertThat(extractedFields.getDocValueFields().size(), equalTo(2)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time")); @@ -77,7 +77,7 @@ public class ExtractedFieldsTests extends ESTestCase { when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword); ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"), - Collections.emptySet(), fieldCapabilitiesResponse); + Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()); assertThat(extractedFields.getDocValueFields().size(), equalTo(1)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword")); @@ -119,7 +119,7 @@ public class ExtractedFieldsTests extends ESTestCase { FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build( - Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse)); + 
Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap())); assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings")); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml index 81077b1e69a..ab6cfac9515 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml @@ -142,7 +142,7 @@ id: "start_given_empty_dest_index" --- -"Test start classification analysis when the dependent variable cardinality is too low or too high": +"Test start classification analysis when the dependent variable cardinality is too low": - do: indices.create: index: index-with-dep-var-with-too-high-card @@ -179,22 +179,3 @@ catch: /Field \[keyword_field\] must have at least \[2\] distinct values but there were \[1\]/ ml.start_data_frame_analytics: id: "classification-cardinality-limits" - - - do: - index: - index: index-with-dep-var-with-too-high-card - body: { numeric_field: 2.0, keyword_field: "class_b" } - - - do: - index: - index: index-with-dep-var-with-too-high-card - body: { numeric_field: 3.0, keyword_field: "class_c" } - - - do: - indices.refresh: - index: index-with-dep-var-with-too-high-card - - - do: - catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/ - ml.start_data_frame_analytics: - id: "classification-cardinality-limits"