[7.x][ML] Extend classification to support multiple classes (#53539) (#53597)

Prepares classification analysis to support more than just
two classes. It introduces a new parameter to the process config
which dictates the `num_classes` to the process. It also
changes the max classes limit to `30` provisionally.

Backport of #53539
This commit is contained in:
Dimitris Athanasiou 2020-03-16 15:00:54 +02:00 committed by GitHub
parent a38e5ca8e7
commit 94da4ca3fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 310 additions and 65 deletions

View File

@ -46,9 +46,16 @@ public class Classification implements DataFrameAnalysis {
private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
private static final String NUM_CLASSES = "num_classes";
private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
/**
* The max number of classes classification supports
*/
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
@ -220,7 +227,7 @@ public class Classification implements DataFrameAnalysis {
}
@Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());
@ -229,10 +236,11 @@ public class Classification implements DataFrameAnalysis {
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
if (predictionFieldType != null) {
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
}
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
return params;
}
@ -274,7 +282,7 @@ public class Classification implements DataFrameAnalysis {
@Override
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
// This restriction is due to the fact that currently the C++ backend only supports binomial classification.
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, MAX_DEPENDENT_VARIABLE_CARDINALITY));
}
@SuppressWarnings("unchecked")

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
@ -16,9 +17,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
/**
* @return The analysis parameters as a map
* @param extractedFields map of (name, types) for all the extracted fields
* @param fieldInfo Information about the fields like types and cardinalities
*/
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
Map<String, Object> getParams(FieldInfo fieldInfo);
/**
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
@ -64,4 +65,27 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
* Returns the document id for the analysis state
*/
String getStateDocId(String jobId);
/**
* Summarizes information about the fields that is necessary for analysis to generate
* the parameters needed for the process configuration.
*/
interface FieldInfo {
/**
* Returns the types for the given field or {@code null} if the field is unknown
* @param field the field whose types to return
* @return the types for the given field or {@code null} if the field is unknown
*/
@Nullable
Set<String> getTypes(String field);
/**
* Returns the cardinality of the given field or {@code null} if there is no cardinality for that field
* @param field the field whose cardinality to get
* @return the cardinality of the given field or {@code null} if there is no cardinality for that field
*/
@Nullable
Long getCardinality(String field);
}
}

View File

@ -192,7 +192,7 @@ public class OutlierDetection implements DataFrameAnalysis {
}
@Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>();
if (nNeighbors != null) {
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);

View File

@ -155,7 +155,7 @@ public class Regression implements DataFrameAnalysis {
}
@Override
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
public Map<String, Object> getParams(FieldInfo fieldInfo) {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());

View File

@ -188,34 +188,45 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
}
public void testGetParams() {
Map<String, Set<String>> extractedFields = new HashMap<>(3);
extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
Map<String, Set<String>> fieldTypes = new HashMap<>(3);
fieldTypes.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
fieldTypes.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
fieldTypes.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
Map<String, Long> fieldCardinalities = new HashMap<>();
fieldCardinalities.put("foo", 10L);
fieldCardinalities.put("bar", 20L);
fieldCardinalities.put("baz", 30L);
DataFrameAnalysis.FieldInfo fieldInfo = new TestFieldInfo(fieldTypes, fieldCardinalities);
assertThat(
new Classification("foo").getParams(extractedFields),
new Classification("foo").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "foo"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "foo_prediction"),
hasEntry("prediction_field_type", "bool")));
hasEntry("prediction_field_type", "bool"),
hasEntry("num_classes", 10L)));
assertThat(
new Classification("bar").getParams(extractedFields),
new Classification("bar").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "bar"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "bar_prediction"),
hasEntry("prediction_field_type", "int")));
hasEntry("prediction_field_type", "int"),
hasEntry("num_classes", 20L)));
assertThat(
new Classification("baz").getParams(extractedFields),
new Classification("baz").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "baz"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "baz_prediction"),
hasEntry("prediction_field_type", "string")));
hasEntry("prediction_field_type", "string"),
hasEntry("num_classes", 30L)));
}
public void testRequiredFieldsIsNonEmpty() {
@ -229,7 +240,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
assertThat(constraints.size(), equalTo(1));
assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
assertThat(constraints.get(0).getUpperBound(), equalTo(30L));
}
public void testGetExplicitlyMappedFields() {
@ -328,4 +339,25 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
protected Classification mutateInstanceForVersion(Classification instance, Version version) {
return mutateForVersion(instance, version);
}
private static class TestFieldInfo implements DataFrameAnalysis.FieldInfo {
private final Map<String, Set<String>> fieldTypes;
private final Map<String, Long> fieldCardinalities;
private TestFieldInfo(Map<String, Set<String>> fieldTypes, Map<String, Long> fieldCardinalities) {
this.fieldTypes = fieldTypes;
this.fieldCardinalities = fieldCardinalities;
}
@Override
public Set<String> getTypes(String field) {
return fieldTypes.get(field);
}
@Override
public Long getCardinality(String field) {
return fieldCardinalities.get(field);
}
}
}

View File

@ -323,9 +323,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}
@AwaitsFix(bugUrl = "Muted until ml-cpp supports multiple classes")
public void testDependentVariableCardinalityTooHighError() throws Exception {
initialize("cardinality_too_high");
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
// Index one more document with a class different than the two already used.
client().execute(
IndexAction.INSTANCE,

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@ -27,7 +28,7 @@ public class TimeBasedExtractedFields extends ExtractedFields {
private final ExtractedField timeField;
public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
super(allFields);
super(allFields, Collections.emptyMap());
if (!allFields.contains(timeField)) {
throw new IllegalArgumentException("timeField should also be contained in allFields");
}

View File

@ -58,15 +58,15 @@ public class ExtractedFieldsDetector {
private final DataFrameAnalyticsConfig config;
private final int docValueFieldsLimit;
private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
private final Map<String, Long> fieldCardinalities;
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit,
FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> fieldCardinalities) {
FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.index = Objects.requireNonNull(index);
this.config = Objects.requireNonNull(config);
this.docValueFieldsLimit = docValueFieldsLimit;
this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse);
this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities);
this.cardinalitiesForFieldsWithConstraints = Objects.requireNonNull(cardinalitiesForFieldsWithConstraints);
}
public Tuple<ExtractedFields, List<FieldSelection>> detect() {
@ -286,12 +286,13 @@ public class ExtractedFieldsDetector {
private void checkFieldsWithCardinalityLimit() {
for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
constraint.check(fieldCardinalities.get(constraint.getField()));
constraint.check(cardinalitiesForFieldsWithConstraints.get(constraint.getField()));
}
}
private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) {
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse);
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse,
cardinalitiesForFieldsWithConstraints);
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
if (preferSource) {
@ -321,7 +322,7 @@ public class ExtractedFieldsDetector {
chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection));
}
}
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()));
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints);
}
private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent,
@ -372,7 +373,7 @@ public class ExtractedFieldsDetector {
for (ExtractedField field : extractedFields.getAllFields()) {
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
}
return new ExtractedFields(adjusted);
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
}
private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
@ -389,7 +390,7 @@ public class ExtractedFieldsDetector {
adjusted.add(field);
}
}
return new ExtractedFields(adjusted);
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
}
private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {

View File

@ -14,10 +14,9 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static java.util.stream.Collectors.toMap;
public class AnalyticsProcessConfig implements ToXContentObject {
private static final String JOB_ID = "job_id";
@ -93,12 +92,31 @@ public class AnalyticsProcessConfig implements ToXContentObject {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("name", analysis.getWriteableName());
builder.field(
"parameters",
analysis.getParams(
extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes))));
builder.field("parameters", analysis.getParams(new AnalysisFieldInfo(extractedFields)));
builder.endObject();
return builder;
}
}
private static class AnalysisFieldInfo implements DataFrameAnalysis.FieldInfo {
private final ExtractedFields extractedFields;
AnalysisFieldInfo(ExtractedFields extractedFields) {
this.extractedFields = Objects.requireNonNull(extractedFields);
}
@Override
public Set<String> getTypes(String field) {
Optional<ExtractedField> extractedField = extractedFields.getAllFields().stream()
.filter(f -> f.getName().equals(field))
.findAny();
return extractedField.isPresent() ? extractedField.get().getTypes() : null;
}
@Override
public Long getCardinality(String field) {
return extractedFields.getCardinalitiesForFieldsWithConstraints().get(field);
}
}
}

View File

@ -28,12 +28,14 @@ public class ExtractedFields {
private final List<ExtractedField> allFields;
private final List<ExtractedField> docValueFields;
private final String[] sourceFields;
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
public ExtractedFields(List<ExtractedField> allFields) {
public ExtractedFields(List<ExtractedField> allFields, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
this.allFields = Collections.unmodifiableList(allFields);
this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
.toArray(String[]::new);
this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
}
public List<ExtractedField> getAllFields() {
@ -48,14 +50,20 @@ public class ExtractedFields {
return docValueFields;
}
public Map<String, Long> getCardinalitiesForFieldsWithConstraints() {
return cardinalitiesForFieldsWithConstraints;
}
private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
}
public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields,
FieldCapabilitiesResponse fieldsCapabilities) {
FieldCapabilitiesResponse fieldsCapabilities,
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()));
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()),
cardinalitiesForFieldsWithConstraints);
}
public static TimeField newTimeField(String name, ExtractedField.Method method) {

View File

@ -81,7 +81,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
query = QueryBuilders.matchAllQuery();
extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("keyword")),
new DocValueField("field_2", Collections.singleton("keyword"))));
new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap());
scrollSize = 1000;
headers = Collections.emptyMap();
@ -299,7 +299,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
// Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
extractedFields = new ExtractedFields(Arrays.asList(
(ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")),
(ExtractedField) new SourceField("field_2", Collections.singleton("text"))));
(ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(false, false);
@ -404,7 +404,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
(ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
(ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
(ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))));
(ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(true, true);
assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());

View File

@ -294,10 +294,10 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(SOURCE_INDEX,
buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 3L));
buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 31L));
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [2] distinct values but there were at least [3]"));
assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [30] distinct values but there were at least [31]"));
}
public void testDetect_GivenIgnoredField() {

View File

@ -0,0 +1,170 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasKey;
public class AnalyticsProcessConfigTests extends ESTestCase {
private String jobId;
private long rows;
private int cols;
private ByteSizeValue memoryLimit;
private int threads;
private String resultsField;
private Set<String> categoricalFields;
@Before
public void setUpConfigParams() {
jobId = randomAlphaOfLength(10);
rows = randomNonNegativeLong();
cols = randomIntBetween(1, 42000);
memoryLimit = new ByteSizeValue(randomNonNegativeLong(), ByteSizeUnit.BYTES);
threads = randomIntBetween(1, 8);
resultsField = randomAlphaOfLength(10);
int categoricalFieldsSize = randomIntBetween(0, 5);
categoricalFields = new HashSet<>();
for (int i = 0; i < categoricalFieldsSize; i++) {
categoricalFields.add(randomAlphaOfLength(10));
}
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenOutlierDetection() throws IOException {
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap());
DataFrameAnalysis analysis = new OutlierDetection.Builder().build();
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
Map<String, Object> asMap = toMap(processConfig);
assertRandomizedFields(asMap);
assertThat(asMap, hasKey("analysis"));
Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
assertThat(analysisAsMap, hasEntry("name", "outlier_detection"));
assertThat(analysisAsMap, hasKey("parameters"));
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenRegression() throws IOException {
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float")),
new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap());
DataFrameAnalysis analysis = new Regression("test_dep_var");
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
Map<String, Object> asMap = toMap(processConfig);
assertRandomizedFields(asMap);
assertThat(asMap, hasKey("analysis"));
Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
assertThat(analysisAsMap, hasEntry("name", "regression"));
assertThat(analysisAsMap, hasKey("parameters"));
Map<String, Object> paramsAsMap = (Map<String, Object>) analysisAsMap.get("parameters");
assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var"));
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenClassificationAndDepVarIsKeyword() throws IOException {
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float")),
new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L));
DataFrameAnalysis analysis = new Classification("test_dep_var");
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
Map<String, Object> asMap = toMap(processConfig);
assertRandomizedFields(asMap);
assertThat(asMap, hasKey("analysis"));
Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
assertThat(analysisAsMap, hasEntry("name", "classification"));
assertThat(analysisAsMap, hasKey("parameters"));
Map<String, Object> paramsAsMap = (Map<String, Object>) analysisAsMap.get("parameters");
assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var"));
assertThat(paramsAsMap, hasEntry("prediction_field_type", "string"));
assertThat(paramsAsMap, hasEntry("num_classes", 5));
}
@SuppressWarnings("unchecked")
public void testToXContent_GivenClassificationAndDepVarIsInteger() throws IOException {
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("double")),
new DocValueField("field_2", Collections.singleton("float")),
new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L));
DataFrameAnalysis analysis = new Classification("test_dep_var");
AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
Map<String, Object> asMap = toMap(processConfig);
assertRandomizedFields(asMap);
assertThat(asMap, hasKey("analysis"));
Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
assertThat(analysisAsMap, hasEntry("name", "classification"));
assertThat(analysisAsMap, hasKey("parameters"));
Map<String, Object> paramsAsMap = (Map<String, Object>) analysisAsMap.get("parameters");
assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var"));
assertThat(paramsAsMap, hasEntry("prediction_field_type", "int"));
assertThat(paramsAsMap, hasEntry("num_classes", 8));
}
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalysis analysis, ExtractedFields extractedFields) {
return new AnalyticsProcessConfig(jobId, rows, cols, memoryLimit, threads, resultsField, categoricalFields, analysis,
extractedFields);
}
private static Map<String, Object> toMap(AnalyticsProcessConfig config) throws IOException {
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
return XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(builder), false);
}
}
@SuppressWarnings("unchecked")
private void assertRandomizedFields(Map<String, Object> configAsMap) {
assertThat(configAsMap, hasEntry("job_id", jobId));
assertThat(configAsMap, hasEntry("rows", rows));
assertThat(configAsMap, hasEntry("cols", cols));
assertThat(configAsMap, hasEntry("memory_limit", memoryLimit.getBytes()));
assertThat(configAsMap, hasEntry("threads", threads));
assertThat(configAsMap, hasEntry("results_field", resultsField));
assertThat(configAsMap, hasKey("categorical_fields"));
assertThat((List<String>) configAsMap.get("categorical_fields"), containsInAnyOrder(categoricalFields.toArray()));
}
}

View File

@ -32,7 +32,7 @@ public class ExtractedFieldsTests extends ESTestCase {
ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));
docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap());
assertThat(extractedFields.getAllFields().size(), equalTo(6));
assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
@ -54,7 +54,7 @@ public class ExtractedFieldsTests extends ESTestCase {
when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps);
ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"),
new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse);
new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap());
assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
@ -77,7 +77,7 @@ public class ExtractedFieldsTests extends ESTestCase {
when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword);
ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"),
Collections.emptySet(), fieldCapabilitiesResponse);
Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap());
assertThat(extractedFields.getDocValueFields().size(), equalTo(1));
assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword"));
@ -119,7 +119,7 @@ public class ExtractedFieldsTests extends ESTestCase {
FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build(
Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse));
Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()));
assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
}

View File

@ -142,7 +142,7 @@
id: "start_given_empty_dest_index"
---
"Test start classification analysis when the dependent variable cardinality is too low or too high":
"Test start classification analysis when the dependent variable cardinality is too low":
- do:
indices.create:
index: index-with-dep-var-with-too-high-card
@ -179,22 +179,3 @@
catch: /Field \[keyword_field\] must have at least \[2\] distinct values but there were \[1\]/
ml.start_data_frame_analytics:
id: "classification-cardinality-limits"
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 2.0, keyword_field: "class_b" }
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 3.0, keyword_field: "class_c" }
- do:
indices.refresh:
index: index-with-dep-var-with-too-high-card
- do:
catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
ml.start_data_frame_analytics:
id: "classification-cardinality-limits"