[7.x][ML] Extend classification to support multiple classes (#53539) (#53597)

Prepares classification analysis to support more than just
two classes. It introduces a new parameter, `num_classes`,
to the process config, which tells the process how many classes
there are. It also provisionally raises the max classes limit to `30`.

Backport of #53539
This commit is contained in:
Dimitris Athanasiou 2020-03-16 15:00:54 +02:00 committed by GitHub
parent a38e5ca8e7
commit 94da4ca3fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 310 additions and 65 deletions

View File

@ -46,9 +46,16 @@ public class Classification implements DataFrameAnalysis {
private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1"; private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
private static final String NUM_CLASSES = "num_classes";
private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false); private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
/**
* The max number of classes classification supports
*/
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) { private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>( ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(), NAME.getPreferredName(),
@ -220,7 +227,7 @@ public class Classification implements DataFrameAnalysis {
} }
@Override @Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) { public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams()); params.putAll(boostedTreeParams.getParams());
@ -229,10 +236,11 @@ public class Classification implements DataFrameAnalysis {
if (predictionFieldName != null) { if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
} }
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable)); String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
if (predictionFieldType != null) { if (predictionFieldType != null) {
params.put(PREDICTION_FIELD_TYPE, predictionFieldType); params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
} }
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
return params; return params;
} }
@ -274,7 +282,7 @@ public class Classification implements DataFrameAnalysis {
@Override @Override
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() { public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
// This restriction is due to the fact that currently the C++ backend only supports binomial classification. // This restriction is due to the fact that currently the C++ backend only supports binomial classification.
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2)); return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, MAX_DEPENDENT_VARIABLE_CARDINALITY));
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.analyses; package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
@ -16,9 +17,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
/** /**
* @return The analysis parameters as a map * @return The analysis parameters as a map
* @param extractedFields map of (name, types) for all the extracted fields * @param fieldInfo Information about the fields like types and cardinalities
*/ */
Map<String, Object> getParams(Map<String, Set<String>> extractedFields); Map<String, Object> getParams(FieldInfo fieldInfo);
/** /**
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip) * @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
@ -64,4 +65,27 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
* Returns the document id for the analysis state * Returns the document id for the analysis state
*/ */
String getStateDocId(String jobId); String getStateDocId(String jobId);
/**
* Summarizes information about the fields that is necessary for analysis to generate
* the parameters needed for the process configuration.
*/
interface FieldInfo {
/**
* Returns the types for the given field or {@code null} if the field is unknown
* @param field the field whose types to return
* @return the types for the given field or {@code null} if the field is unknown
*/
@Nullable
Set<String> getTypes(String field);
/**
* Returns the cardinality of the given field or {@code null} if there is no cardinality for that field
* @param field the field whose cardinality to get
* @return the cardinality of the given field or {@code null} if there is no cardinality for that field
*/
@Nullable
Long getCardinality(String field);
}
} }

View File

@ -192,7 +192,7 @@ public class OutlierDetection implements DataFrameAnalysis {
} }
@Override @Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) { public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
if (nNeighbors != null) { if (nNeighbors != null) {
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);

View File

@ -155,7 +155,7 @@ public class Regression implements DataFrameAnalysis {
} }
@Override @Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) { public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams()); params.putAll(boostedTreeParams.getParams());

View File

@ -188,34 +188,45 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
} }
public void testGetParams() { public void testGetParams() {
Map<String, Set<String>> extractedFields = new HashMap<>(3); Map<String, Set<String>> fieldTypes = new HashMap<>(3);
extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE)); fieldTypes.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName())); fieldTypes.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE)); fieldTypes.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
Map<String, Long> fieldCardinalities = new HashMap<>();
fieldCardinalities.put("foo", 10L);
fieldCardinalities.put("bar", 20L);
fieldCardinalities.put("baz", 30L);
DataFrameAnalysis.FieldInfo fieldInfo = new TestFieldInfo(fieldTypes, fieldCardinalities);
assertThat( assertThat(
new Classification("foo").getParams(extractedFields), new Classification("foo").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf( Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "foo"), hasEntry("dependent_variable", "foo"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2), hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "foo_prediction"), hasEntry("prediction_field_name", "foo_prediction"),
hasEntry("prediction_field_type", "bool"))); hasEntry("prediction_field_type", "bool"),
hasEntry("num_classes", 10L)));
assertThat( assertThat(
new Classification("bar").getParams(extractedFields), new Classification("bar").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf( Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "bar"), hasEntry("dependent_variable", "bar"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2), hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "bar_prediction"), hasEntry("prediction_field_name", "bar_prediction"),
hasEntry("prediction_field_type", "int"))); hasEntry("prediction_field_type", "int"),
hasEntry("num_classes", 20L)));
assertThat( assertThat(
new Classification("baz").getParams(extractedFields), new Classification("baz").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf( Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "baz"), hasEntry("dependent_variable", "baz"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2), hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "baz_prediction"), hasEntry("prediction_field_name", "baz_prediction"),
hasEntry("prediction_field_type", "string"))); hasEntry("prediction_field_type", "string"),
hasEntry("num_classes", 30L)));
} }
public void testRequiredFieldsIsNonEmpty() { public void testRequiredFieldsIsNonEmpty() {
@ -229,7 +240,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
assertThat(constraints.size(), equalTo(1)); assertThat(constraints.size(), equalTo(1));
assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable())); assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
assertThat(constraints.get(0).getLowerBound(), equalTo(2L)); assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
assertThat(constraints.get(0).getUpperBound(), equalTo(2L)); assertThat(constraints.get(0).getUpperBound(), equalTo(30L));
} }
public void testGetExplicitlyMappedFields() { public void testGetExplicitlyMappedFields() {
@ -328,4 +339,25 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
protected Classification mutateInstanceForVersion(Classification instance, Version version) { protected Classification mutateInstanceForVersion(Classification instance, Version version) {
return mutateForVersion(instance, version); return mutateForVersion(instance, version);
} }
private static class TestFieldInfo implements DataFrameAnalysis.FieldInfo {
private final Map<String, Set<String>> fieldTypes;
private final Map<String, Long> fieldCardinalities;
private TestFieldInfo(Map<String, Set<String>> fieldTypes, Map<String, Long> fieldCardinalities) {
this.fieldTypes = fieldTypes;
this.fieldCardinalities = fieldCardinalities;
}
@Override
public Set<String> getTypes(String field) {
return fieldTypes.get(field);
}
@Override
public Long getCardinality(String field) {
return fieldCardinalities.get(field);
}
}
} }

View File

@ -323,9 +323,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
} }
@AwaitsFix(bugUrl = "Muted until ml-cpp supports multiple classes")
public void testDependentVariableCardinalityTooHighError() throws Exception { public void testDependentVariableCardinalityTooHighError() throws Exception {
initialize("cardinality_too_high"); initialize("cardinality_too_high");
indexData(sourceIndex, 6, 5, KEYWORD_FIELD); indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
// Index one more document with a class different than the two already used. // Index one more document with a class different than the two already used.
client().execute( client().execute(
IndexAction.INSTANCE, IndexAction.INSTANCE,

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
@ -27,7 +28,7 @@ public class TimeBasedExtractedFields extends ExtractedFields {
private final ExtractedField timeField; private final ExtractedField timeField;
public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) { public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
super(allFields); super(allFields, Collections.emptyMap());
if (!allFields.contains(timeField)) { if (!allFields.contains(timeField)) {
throw new IllegalArgumentException("timeField should also be contained in allFields"); throw new IllegalArgumentException("timeField should also be contained in allFields");
} }

View File

@ -58,15 +58,15 @@ public class ExtractedFieldsDetector {
private final DataFrameAnalyticsConfig config; private final DataFrameAnalyticsConfig config;
private final int docValueFieldsLimit; private final int docValueFieldsLimit;
private final FieldCapabilitiesResponse fieldCapabilitiesResponse; private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
private final Map<String, Long> fieldCardinalities; private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit, ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit,
FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> fieldCardinalities) { FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.index = Objects.requireNonNull(index); this.index = Objects.requireNonNull(index);
this.config = Objects.requireNonNull(config); this.config = Objects.requireNonNull(config);
this.docValueFieldsLimit = docValueFieldsLimit; this.docValueFieldsLimit = docValueFieldsLimit;
this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse); this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse);
this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities); this.cardinalitiesForFieldsWithConstraints = Objects.requireNonNull(cardinalitiesForFieldsWithConstraints);
} }
public Tuple<ExtractedFields, List<FieldSelection>> detect() { public Tuple<ExtractedFields, List<FieldSelection>> detect() {
@ -286,12 +286,13 @@ public class ExtractedFieldsDetector {
private void checkFieldsWithCardinalityLimit() { private void checkFieldsWithCardinalityLimit() {
for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) { for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
constraint.check(fieldCardinalities.get(constraint.getField())); constraint.check(cardinalitiesForFieldsWithConstraints.get(constraint.getField()));
} }
} }
private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) { private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) {
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse); ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse,
cardinalitiesForFieldsWithConstraints);
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit; boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection); extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
if (preferSource) { if (preferSource) {
@ -321,7 +322,7 @@ public class ExtractedFieldsDetector {
chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection)); chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection));
} }
} }
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values())); return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints);
} }
private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent, private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent,
@ -372,7 +373,7 @@ public class ExtractedFieldsDetector {
for (ExtractedField field : extractedFields.getAllFields()) { for (ExtractedField field : extractedFields.getAllFields()) {
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field); adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
} }
return new ExtractedFields(adjusted); return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
} }
private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) { private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
@ -389,7 +390,7 @@ public class ExtractedFieldsDetector {
adjusted.add(field); adjusted.add(field);
} }
} }
return new ExtractedFields(adjusted); return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
} }
private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) { private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {

View File

@ -14,10 +14,9 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import static java.util.stream.Collectors.toMap;
public class AnalyticsProcessConfig implements ToXContentObject { public class AnalyticsProcessConfig implements ToXContentObject {
private static final String JOB_ID = "job_id"; private static final String JOB_ID = "job_id";
@ -93,12 +92,31 @@ public class AnalyticsProcessConfig implements ToXContentObject {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field("name", analysis.getWriteableName()); builder.field("name", analysis.getWriteableName());
builder.field( builder.field("parameters", analysis.getParams(new AnalysisFieldInfo(extractedFields)));
"parameters",
analysis.getParams(
extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes))));
builder.endObject(); builder.endObject();
return builder; return builder;
} }
} }
private static class AnalysisFieldInfo implements DataFrameAnalysis.FieldInfo {
private final ExtractedFields extractedFields;
AnalysisFieldInfo(ExtractedFields extractedFields) {
this.extractedFields = Objects.requireNonNull(extractedFields);
}
@Override
public Set<String> getTypes(String field) {
Optional<ExtractedField> extractedField = extractedFields.getAllFields().stream()
.filter(f -> f.getName().equals(field))
.findAny();
return extractedField.isPresent() ? extractedField.get().getTypes() : null;
}
@Override
public Long getCardinality(String field) {
return extractedFields.getCardinalitiesForFieldsWithConstraints().get(field);
}
}
} }

View File

@ -28,12 +28,14 @@ public class ExtractedFields {
private final List<ExtractedField> allFields; private final List<ExtractedField> allFields;
private final List<ExtractedField> docValueFields; private final List<ExtractedField> docValueFields;
private final String[] sourceFields; private final String[] sourceFields;
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
public ExtractedFields(List<ExtractedField> allFields) { public ExtractedFields(List<ExtractedField> allFields, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.allFields = Collections.unmodifiableList(allFields); this.allFields = Collections.unmodifiableList(allFields);
this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields); this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField) this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
.toArray(String[]::new); .toArray(String[]::new);
this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
} }
public List<ExtractedField> getAllFields() { public List<ExtractedField> getAllFields() {
@ -48,14 +50,20 @@ public class ExtractedFields {
return docValueFields; return docValueFields;
} }
public Map<String, Long> getCardinalitiesForFieldsWithConstraints() {
return cardinalitiesForFieldsWithConstraints;
}
private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) { private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList()); return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
} }
public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields, public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields,
FieldCapabilitiesResponse fieldsCapabilities) { FieldCapabilitiesResponse fieldsCapabilities,
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities); ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList())); return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()),
cardinalitiesForFieldsWithConstraints);
} }
public static TimeField newTimeField(String name, ExtractedField.Method method) { public static TimeField newTimeField(String name, ExtractedField.Method method) {

View File

@ -81,7 +81,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
query = QueryBuilders.matchAllQuery(); query = QueryBuilders.matchAllQuery();
extractedFields = new ExtractedFields(Arrays.asList( extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("keyword")), new DocValueField("field_1", Collections.singleton("keyword")),
new DocValueField("field_2", Collections.singleton("keyword")))); new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap());
scrollSize = 1000; scrollSize = 1000;
headers = Collections.emptyMap(); headers = Collections.emptyMap();
@ -299,7 +299,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
// Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915 // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
extractedFields = new ExtractedFields(Arrays.asList( extractedFields = new ExtractedFields(Arrays.asList(
(ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")), (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")),
(ExtractedField) new SourceField("field_2", Collections.singleton("text")))); (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(false, false); TestExtractor dataExtractor = createExtractor(false, false);
@ -404,7 +404,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")), (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")), (ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")), (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
(ExtractedField) new SourceField("field_text", Collections.singleton("text")))); (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(true, true); TestExtractor dataExtractor = createExtractor(true, true);
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty()); assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());

View File

@ -294,10 +294,10 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build(); .build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(SOURCE_INDEX, ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(SOURCE_INDEX,
buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 3L)); buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 31L));
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [2] distinct values but there were at least [3]")); assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [30] distinct values but there were at least [31]"));
} }
public void testDetect_GivenIgnoredField() { public void testDetect_GivenIgnoredField() {

View File

@ -0,0 +1,170 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasKey;
public class AnalyticsProcessConfigTests extends ESTestCase {
private String jobId;
private long rows;
private int cols;
private ByteSizeValue memoryLimit;
private int threads;
private String resultsField;
private Set<String> categoricalFields;
@Before
public void setUpConfigParams() {
jobId = randomAlphaOfLength(10);
rows = randomNonNegativeLong();
cols = randomIntBetween(1, 42000);
memoryLimit = new ByteSizeValue(randomNonNegativeLong(), ByteSizeUnit.BYTES);
threads = randomIntBetween(1, 8);
resultsField = randomAlphaOfLength(10);
int categoricalFieldsSize = randomIntBetween(0, 5);
categoricalFields = new HashSet<>();
for (int i = 0; i < categoricalFieldsSize; i++) {
categoricalFields.add(randomAlphaOfLength(10));
}
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenOutlierDetection() throws IOException {
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap());
DataFrameAnalysis analysis = new OutlierDetection.Builder().build();
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
Map<String, Object> asMap = toMap(processConfig);
assertRandomizedFields(asMap);
assertThat(asMap, hasKey("analysis"));
Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
assertThat(analysisAsMap, hasEntry("name", "outlier_detection"));
assertThat(analysisAsMap, hasKey("parameters"));
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenRegression() throws IOException {
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float")),
new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap());
DataFrameAnalysis analysis = new Regression("test_dep_var");
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
Map<String, Object> asMap = toMap(processConfig);
assertRandomizedFields(asMap);
assertThat(asMap, hasKey("analysis"));
Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
assertThat(analysisAsMap, hasEntry("name", "regression"));
assertThat(analysisAsMap, hasKey("parameters"));
Map<String, Object> paramsAsMap = (Map<String, Object>) analysisAsMap.get("parameters");
assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var"));
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenClassificationAndDepVarIsKeyword() throws IOException {
    // A keyword dependent variable with 5 distinct values should yield a "string"
    // prediction field type and num_classes equal to the observed cardinality.
    DataFrameAnalysis analysis = new Classification("test_dep_var");
    ExtractedFields fields = new ExtractedFields(
        Arrays.asList(
            new DocValueField("field_1", Collections.singleton("double")),
            new DocValueField("field_2", Collections.singleton("float")),
            new DocValueField("test_dep_var", Collections.singleton("keyword"))),
        Collections.singletonMap("test_dep_var", 5L));

    Map<String, Object> configMap = toMap(createProcessConfig(analysis, fields));

    assertRandomizedFields(configMap);
    assertThat(configMap, hasKey("analysis"));
    Map<String, Object> analysisMap = (Map<String, Object>) configMap.get("analysis");
    assertThat(analysisMap, hasEntry("name", "classification"));
    assertThat(analysisMap, hasKey("parameters"));
    Map<String, Object> params = (Map<String, Object>) analysisMap.get("parameters");
    assertThat(params, hasEntry("dependent_variable", "test_dep_var"));
    assertThat(params, hasEntry("prediction_field_type", "string"));
    assertThat(params, hasEntry("num_classes", 5));
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenClassificationAndDepVarIsInteger() throws IOException {
    // An integer dependent variable with 8 distinct values should yield an "int"
    // prediction field type and num_classes equal to the observed cardinality.
    DataFrameAnalysis analysis = new Classification("test_dep_var");
    ExtractedFields fields = new ExtractedFields(
        Arrays.asList(
            new DocValueField("field_1", Collections.singleton("double")),
            new DocValueField("field_2", Collections.singleton("float")),
            new DocValueField("test_dep_var", Collections.singleton("integer"))),
        Collections.singletonMap("test_dep_var", 8L));

    Map<String, Object> configMap = toMap(createProcessConfig(analysis, fields));

    assertRandomizedFields(configMap);
    assertThat(configMap, hasKey("analysis"));
    Map<String, Object> analysisMap = (Map<String, Object>) configMap.get("analysis");
    assertThat(analysisMap, hasEntry("name", "classification"));
    assertThat(analysisMap, hasKey("parameters"));
    Map<String, Object> params = (Map<String, Object>) analysisMap.get("parameters");
    assertThat(params, hasEntry("dependent_variable", "test_dep_var"));
    assertThat(params, hasEntry("prediction_field_type", "int"));
    assertThat(params, hasEntry("num_classes", 8));
}
/**
 * Builds a process config from the randomized fixture fields plus the
 * given analysis and extracted fields.
 */
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalysis analysis, ExtractedFields extractedFields) {
    return new AnalyticsProcessConfig(
        jobId, rows, cols, memoryLimit, threads, resultsField, categoricalFields, analysis, extractedFields);
}
/**
 * Serializes the config to JSON via its {@code toXContent} implementation
 * and parses the result back into a map for assertion purposes.
 */
private static Map<String, Object> toMap(AnalyticsProcessConfig config) throws IOException {
    try (XContentBuilder xContentBuilder = JsonXContent.contentBuilder()) {
        config.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS);
        String json = Strings.toString(xContentBuilder);
        return XContentHelper.convertToMap(JsonXContent.jsonXContent, json, false);
    }
}
/**
 * Asserts that all the randomized fixture fields were serialized with
 * their expected keys and values.
 */
@SuppressWarnings("unchecked")
private void assertRandomizedFields(Map<String, Object> asMap) {
    assertThat(asMap, hasEntry("job_id", jobId));
    assertThat(asMap, hasEntry("rows", rows));
    assertThat(asMap, hasEntry("cols", cols));
    assertThat(asMap, hasEntry("memory_limit", memoryLimit.getBytes()));
    assertThat(asMap, hasEntry("threads", threads));
    assertThat(asMap, hasEntry("results_field", resultsField));
    assertThat(asMap, hasKey("categorical_fields"));
    // Serialization order of categorical fields is not guaranteed, so compare as an unordered set.
    List<String> serializedCategoricalFields = (List<String>) asMap.get("categorical_fields");
    assertThat(serializedCategoricalFields, containsInAnyOrder(categoricalFields.toArray()));
}
}

View File

@ -32,7 +32,7 @@ public class ExtractedFieldsTests extends ESTestCase {
ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text")); ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text")); ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap());
assertThat(extractedFields.getAllFields().size(), equalTo(6)); assertThat(extractedFields.getAllFields().size(), equalTo(6));
assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new), assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
@ -54,7 +54,7 @@ public class ExtractedFieldsTests extends ESTestCase {
when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps); when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps);
ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"), ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"),
new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse); new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap());
assertThat(extractedFields.getDocValueFields().size(), equalTo(2)); assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time")); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
@ -77,7 +77,7 @@ public class ExtractedFieldsTests extends ESTestCase {
when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword); when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword);
ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"), ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"),
Collections.emptySet(), fieldCapabilitiesResponse); Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap());
assertThat(extractedFields.getDocValueFields().size(), equalTo(1)); assertThat(extractedFields.getDocValueFields().size(), equalTo(1));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword")); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword"));
@ -119,7 +119,7 @@ public class ExtractedFieldsTests extends ESTestCase {
FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class); FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build( IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build(
Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse)); Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()));
assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings")); assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
} }

View File

@ -142,7 +142,7 @@
id: "start_given_empty_dest_index" id: "start_given_empty_dest_index"
--- ---
"Test start classification analysis when the dependent variable cardinality is too low or too high": "Test start classification analysis when the dependent variable cardinality is too low":
- do: - do:
indices.create: indices.create:
index: index-with-dep-var-with-too-high-card index: index-with-dep-var-with-too-high-card
@ -179,22 +179,3 @@
catch: /Field \[keyword_field\] must have at least \[2\] distinct values but there were \[1\]/ catch: /Field \[keyword_field\] must have at least \[2\] distinct values but there were \[1\]/
ml.start_data_frame_analytics: ml.start_data_frame_analytics:
id: "classification-cardinality-limits" id: "classification-cardinality-limits"
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 2.0, keyword_field: "class_b" }
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 3.0, keyword_field: "class_c" }
- do:
indices.refresh:
index: index-with-dep-var-with-too-high-card
- do:
catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
ml.start_data_frame_analytics:
id: "classification-cardinality-limits"