Prepares the classification analysis to support more than two classes. It introduces a new `num_classes` parameter in the process config that tells the native process how many classes the dependent variable has. It also provisionally raises the maximum number of classes to `30`. Backport of #53539
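For illustration, the `analysis` object that the process config now serializes for a classification job would look roughly as follows. This is a sketch assembled from the unit tests in this diff; the exact JSON layout and the rendering of `class_assignment_objective` are assumptions, not verbatim process output:

    {
      "name": "classification",
      "parameters": {
        "dependent_variable": "foo",
        "class_assignment_objective": "maximize_minimum_recall",
        "num_top_classes": 2,
        "prediction_field_name": "foo_prediction",
        "prediction_field_type": "bool",
        "num_classes": 10
      }
    }

The new `num_classes` entry is derived from the cardinality of the dependent variable, which the new `FieldInfo` interface exposes to each analysis.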
This commit is contained in:
parent a38e5ca8e7
commit 94da4ca3fc
@@ -46,9 +46,16 @@ public class Classification implements DataFrameAnalysis {
 
     private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
 
+    private static final String NUM_CLASSES = "num_classes";
+
     private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
     private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
 
+    /**
+     * The max number of classes classification supports
+     */
+    private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
+
     private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
         ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
             NAME.getPreferredName(),
@@ -220,7 +227,7 @@ public class Classification implements DataFrameAnalysis {
     }
 
     @Override
-    public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
+    public Map<String, Object> getParams(FieldInfo fieldInfo) {
         Map<String, Object> params = new HashMap<>();
         params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
         params.putAll(boostedTreeParams.getParams());
@@ -229,10 +236,11 @@ public class Classification implements DataFrameAnalysis {
         if (predictionFieldName != null) {
             params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
         }
-        String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
+        String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
         if (predictionFieldType != null) {
             params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
         }
+        params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
         return params;
     }
 
@@ -274,7 +282,7 @@ public class Classification implements DataFrameAnalysis {
     @Override
     public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
-        // This restriction is due to the fact that currently the C++ backend only supports binomial classification.
-        return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
+        return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, MAX_DEPENDENT_VARIABLE_CARDINALITY));
     }
 
     @SuppressWarnings("unchecked")
@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.analyses;
 
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -16,9 +17,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
 
     /**
      * @return The analysis parameters as a map
-     * @param extractedFields map of (name, types) for all the extracted fields
+     * @param fieldInfo Information about the fields like types and cardinalities
      */
-    Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
+    Map<String, Object> getParams(FieldInfo fieldInfo);
 
     /**
      * @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
@@ -64,4 +65,27 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
      * Returns the document id for the analysis state
      */
     String getStateDocId(String jobId);
+
+    /**
+     * Summarizes information about the fields that is necessary for analysis to generate
+     * the parameters needed for the process configuration.
+     */
+    interface FieldInfo {
+
+        /**
+         * Returns the types for the given field or {@code null} if the field is unknown
+         * @param field the field whose types to return
+         * @return the types for the given field or {@code null} if the field is unknown
+         */
+        @Nullable
+        Set<String> getTypes(String field);
+
+        /**
+         * Returns the cardinality of the given field or {@code null} if there is no cardinality for that field
+         * @param field the field whose cardinality to get
+         * @return the cardinality of the given field or {@code null} if there is no cardinality for that field
+         */
+        @Nullable
+        Long getCardinality(String field);
+    }
 }
@@ -192,7 +192,7 @@ public class OutlierDetection implements DataFrameAnalysis {
     }
 
     @Override
-    public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
+    public Map<String, Object> getParams(FieldInfo fieldInfo) {
         Map<String, Object> params = new HashMap<>();
         if (nNeighbors != null) {
             params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);
@@ -155,7 +155,7 @@ public class Regression implements DataFrameAnalysis {
     }
 
     @Override
-    public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
+    public Map<String, Object> getParams(FieldInfo fieldInfo) {
         Map<String, Object> params = new HashMap<>();
         params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
         params.putAll(boostedTreeParams.getParams());
@@ -188,34 +188,45 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
     }
 
     public void testGetParams() {
-        Map<String, Set<String>> extractedFields = new HashMap<>(3);
-        extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
-        extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
-        extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
+        Map<String, Set<String>> fieldTypes = new HashMap<>(3);
+        fieldTypes.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
+        fieldTypes.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
+        fieldTypes.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
+
+        Map<String, Long> fieldCardinalities = new HashMap<>();
+        fieldCardinalities.put("foo", 10L);
+        fieldCardinalities.put("bar", 20L);
+        fieldCardinalities.put("baz", 30L);
+
+        DataFrameAnalysis.FieldInfo fieldInfo = new TestFieldInfo(fieldTypes, fieldCardinalities);
+
         assertThat(
-            new Classification("foo").getParams(extractedFields),
+            new Classification("foo").getParams(fieldInfo),
             Matchers.<Map<String, Object>>allOf(
                 hasEntry("dependent_variable", "foo"),
                 hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
                 hasEntry("num_top_classes", 2),
                 hasEntry("prediction_field_name", "foo_prediction"),
-                hasEntry("prediction_field_type", "bool")));
+                hasEntry("prediction_field_type", "bool"),
+                hasEntry("num_classes", 10L)));
         assertThat(
-            new Classification("bar").getParams(extractedFields),
+            new Classification("bar").getParams(fieldInfo),
             Matchers.<Map<String, Object>>allOf(
                 hasEntry("dependent_variable", "bar"),
                 hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
                 hasEntry("num_top_classes", 2),
                 hasEntry("prediction_field_name", "bar_prediction"),
-                hasEntry("prediction_field_type", "int")));
+                hasEntry("prediction_field_type", "int"),
+                hasEntry("num_classes", 20L)));
         assertThat(
-            new Classification("baz").getParams(extractedFields),
+            new Classification("baz").getParams(fieldInfo),
             Matchers.<Map<String, Object>>allOf(
                 hasEntry("dependent_variable", "baz"),
                 hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
                 hasEntry("num_top_classes", 2),
                 hasEntry("prediction_field_name", "baz_prediction"),
-                hasEntry("prediction_field_type", "string")));
+                hasEntry("prediction_field_type", "string"),
+                hasEntry("num_classes", 30L)));
     }
 
     public void testRequiredFieldsIsNonEmpty() {
@@ -229,7 +240,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         assertThat(constraints.size(), equalTo(1));
         assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
         assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
-        assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
+        assertThat(constraints.get(0).getUpperBound(), equalTo(30L));
     }
 
     public void testGetExplicitlyMappedFields() {
@@ -328,4 +339,25 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
     protected Classification mutateInstanceForVersion(Classification instance, Version version) {
         return mutateForVersion(instance, version);
     }
+
+    private static class TestFieldInfo implements DataFrameAnalysis.FieldInfo {
+
+        private final Map<String, Set<String>> fieldTypes;
+        private final Map<String, Long> fieldCardinalities;
+
+        private TestFieldInfo(Map<String, Set<String>> fieldTypes, Map<String, Long> fieldCardinalities) {
+            this.fieldTypes = fieldTypes;
+            this.fieldCardinalities = fieldCardinalities;
+        }
+
+        @Override
+        public Set<String> getTypes(String field) {
+            return fieldTypes.get(field);
+        }
+
+        @Override
+        public Long getCardinality(String field) {
+            return fieldCardinalities.get(field);
+        }
+    }
 }
@@ -323,9 +323,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
 
+    @AwaitsFix(bugUrl = "Muted until ml-cpp supports multiple classes")
     public void testDependentVariableCardinalityTooHighError() throws Exception {
         initialize("cardinality_too_high");
         indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
 
+        // Index one more document with a class different than the two already used.
         client().execute(
             IndexAction.INSTANCE,
@@ -14,6 +14,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
@@ -27,7 +28,7 @@ public class TimeBasedExtractedFields extends ExtractedFields {
     private final ExtractedField timeField;
 
     public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
-        super(allFields);
+        super(allFields, Collections.emptyMap());
         if (!allFields.contains(timeField)) {
             throw new IllegalArgumentException("timeField should also be contained in allFields");
         }
@@ -58,15 +58,15 @@ public class ExtractedFieldsDetector {
     private final DataFrameAnalyticsConfig config;
     private final int docValueFieldsLimit;
     private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
-    private final Map<String, Long> fieldCardinalities;
+    private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
 
     ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit,
-                            FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> fieldCardinalities) {
+                            FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
         this.index = Objects.requireNonNull(index);
         this.config = Objects.requireNonNull(config);
         this.docValueFieldsLimit = docValueFieldsLimit;
         this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse);
-        this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities);
+        this.cardinalitiesForFieldsWithConstraints = Objects.requireNonNull(cardinalitiesForFieldsWithConstraints);
     }
 
     public Tuple<ExtractedFields, List<FieldSelection>> detect() {
@@ -286,12 +286,13 @@ public class ExtractedFieldsDetector {
 
     private void checkFieldsWithCardinalityLimit() {
         for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
-            constraint.check(fieldCardinalities.get(constraint.getField()));
+            constraint.check(cardinalitiesForFieldsWithConstraints.get(constraint.getField()));
         }
     }
 
     private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) {
-        ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse);
+        ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse,
+            cardinalitiesForFieldsWithConstraints);
         boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
         extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
         if (preferSource) {
@@ -321,7 +322,7 @@ public class ExtractedFieldsDetector {
                     chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection));
             }
         }
-        return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()));
+        return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints);
     }
 
     private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent,
@@ -372,7 +373,7 @@ public class ExtractedFieldsDetector {
         for (ExtractedField field : extractedFields.getAllFields()) {
             adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
         }
-        return new ExtractedFields(adjusted);
+        return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
     }
 
     private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
@@ -389,7 +390,7 @@ public class ExtractedFieldsDetector {
                 adjusted.add(field);
             }
         }
-        return new ExtractedFields(adjusted);
+        return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
     }
 
     private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {
@@ -14,10 +14,9 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 
 import java.io.IOException;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 
-import static java.util.stream.Collectors.toMap;
-
 public class AnalyticsProcessConfig implements ToXContentObject {
 
     private static final String JOB_ID = "job_id";
@@ -93,12 +92,31 @@ public class AnalyticsProcessConfig implements ToXContentObject {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field("name", analysis.getWriteableName());
-        builder.field(
-            "parameters",
-            analysis.getParams(
-                extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes))));
+        builder.field("parameters", analysis.getParams(new AnalysisFieldInfo(extractedFields)));
         builder.endObject();
         return builder;
     }
+
+    private static class AnalysisFieldInfo implements DataFrameAnalysis.FieldInfo {
+
+        private final ExtractedFields extractedFields;
+
+        AnalysisFieldInfo(ExtractedFields extractedFields) {
+            this.extractedFields = Objects.requireNonNull(extractedFields);
+        }
+
+        @Override
+        public Set<String> getTypes(String field) {
+            Optional<ExtractedField> extractedField = extractedFields.getAllFields().stream()
+                .filter(f -> f.getName().equals(field))
+                .findAny();
+            return extractedField.isPresent() ? extractedField.get().getTypes() : null;
+        }
+
+        @Override
+        public Long getCardinality(String field) {
+            return extractedFields.getCardinalitiesForFieldsWithConstraints().get(field);
+        }
+    }
 }
@@ -28,12 +28,14 @@ public class ExtractedFields {
     private final List<ExtractedField> allFields;
    private final List<ExtractedField> docValueFields;
    private final String[] sourceFields;
+    private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
 
-    public ExtractedFields(List<ExtractedField> allFields) {
+    public ExtractedFields(List<ExtractedField> allFields, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
         this.allFields = Collections.unmodifiableList(allFields);
         this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
         this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
             .toArray(String[]::new);
+        this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
     }
 
     public List<ExtractedField> getAllFields() {
@@ -48,14 +50,20 @@ public class ExtractedFields {
         return docValueFields;
     }
 
+    public Map<String, Long> getCardinalitiesForFieldsWithConstraints() {
+        return cardinalitiesForFieldsWithConstraints;
+    }
+
     private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
         return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
     }
 
     public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields,
-                                        FieldCapabilitiesResponse fieldsCapabilities) {
+                                        FieldCapabilitiesResponse fieldsCapabilities,
+                                        Map<String, Long> cardinalitiesForFieldsWithConstraints) {
         ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
-        return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()));
+        return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()),
+            cardinalitiesForFieldsWithConstraints);
     }
 
     public static TimeField newTimeField(String name, ExtractedField.Method method) {
@@ -81,7 +81,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         query = QueryBuilders.matchAllQuery();
         extractedFields = new ExtractedFields(Arrays.asList(
             new DocValueField("field_1", Collections.singleton("keyword")),
-            new DocValueField("field_2", Collections.singleton("keyword"))));
+            new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap());
         scrollSize = 1000;
         headers = Collections.emptyMap();
 
@@ -299,7 +299,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915
         extractedFields = new ExtractedFields(Arrays.asList(
             (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")),
-            (ExtractedField) new SourceField("field_2", Collections.singleton("text"))));
+            (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap());
 
         TestExtractor dataExtractor = createExtractor(false, false);
 
@@ -404,7 +404,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
             (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")),
             (ExtractedField) new DocValueField("field_long", Collections.singleton("long")),
             (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")),
-            (ExtractedField) new SourceField("field_text", Collections.singleton("text"))));
+            (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap());
         TestExtractor dataExtractor = createExtractor(true, true);
 
         assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty());
@@ -294,10 +294,10 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
             .build();
 
         ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(SOURCE_INDEX,
-            buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 3L));
+            buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 31L));
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect);
 
-        assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [2] distinct values but there were at least [3]"));
+        assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [30] distinct values but there were at least [31]"));
     }
 
     public void testDetect_GivenIgnoredField() {
@@ -0,0 +1,170 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.dataframe.process;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.json.JsonXContent;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
+import org.elasticsearch.xpack.ml.extractor.DocValueField;
+import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasKey;
+
+public class AnalyticsProcessConfigTests extends ESTestCase {
+
+    private String jobId;
+    private long rows;
+    private int cols;
+    private ByteSizeValue memoryLimit;
+    private int threads;
+    private String resultsField;
+    private Set<String> categoricalFields;
+
+    @Before
+    public void setUpConfigParams() {
+        jobId = randomAlphaOfLength(10);
+        rows = randomNonNegativeLong();
+        cols = randomIntBetween(1, 42000);
+        memoryLimit = new ByteSizeValue(randomNonNegativeLong(), ByteSizeUnit.BYTES);
+        threads = randomIntBetween(1, 8);
+        resultsField = randomAlphaOfLength(10);
+
+        int categoricalFieldsSize = randomIntBetween(0, 5);
+        categoricalFields = new HashSet<>();
+        for (int i = 0; i < categoricalFieldsSize; i++) {
+            categoricalFields.add(randomAlphaOfLength(10));
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testToXContent_GivenOutlierDetection() throws IOException {
+        ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
+            new DocValueField("field_1", Collections.singleton("double")),
+            new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap());
+        DataFrameAnalysis analysis = new OutlierDetection.Builder().build();
+
+        AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
+        Map<String, Object> asMap = toMap(processConfig);
+
+        assertRandomizedFields(asMap);
+
+        assertThat(asMap, hasKey("analysis"));
+        Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
+        assertThat(analysisAsMap, hasEntry("name", "outlier_detection"));
+        assertThat(analysisAsMap, hasKey("parameters"));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testToXContent_GivenRegression() throws IOException {
+        ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
+            new DocValueField("field_1", Collections.singleton("double")),
+            new DocValueField("field_2", Collections.singleton("float")),
+            new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap());
+        DataFrameAnalysis analysis = new Regression("test_dep_var");
+
+        AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
+        Map<String, Object> asMap = toMap(processConfig);
+
+        assertRandomizedFields(asMap);
+
+        assertThat(asMap, hasKey("analysis"));
+        Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
+        assertThat(analysisAsMap, hasEntry("name", "regression"));
+        assertThat(analysisAsMap, hasKey("parameters"));
+        Map<String, Object> paramsAsMap = (Map<String, Object>) analysisAsMap.get("parameters");
+        assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var"));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testToXContent_GivenClassificationAndDepVarIsKeyword() throws IOException {
+        ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
+            new DocValueField("field_1", Collections.singleton("double")),
+            new DocValueField("field_2", Collections.singleton("float")),
+            new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L));
+        DataFrameAnalysis analysis = new Classification("test_dep_var");
+
+        AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
+        Map<String, Object> asMap = toMap(processConfig);
+
+        assertRandomizedFields(asMap);
+
+        assertThat(asMap, hasKey("analysis"));
+        Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
+        assertThat(analysisAsMap, hasEntry("name", "classification"));
+        assertThat(analysisAsMap, hasKey("parameters"));
+        Map<String, Object> paramsAsMap = (Map<String, Object>) analysisAsMap.get("parameters");
+        assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var"));
+        assertThat(paramsAsMap, hasEntry("prediction_field_type", "string"));
+        assertThat(paramsAsMap, hasEntry("num_classes", 5));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testToXContent_GivenClassificationAndDepVarIsInteger() throws IOException {
+        ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
+            new DocValueField("field_1", Collections.singleton("double")),
+            new DocValueField("field_2", Collections.singleton("float")),
+            new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L));
+        DataFrameAnalysis analysis = new Classification("test_dep_var");
+
+        AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields);
+        Map<String, Object> asMap = toMap(processConfig);
+
+        assertRandomizedFields(asMap);
+
+        assertThat(asMap, hasKey("analysis"));
+        Map<String, Object> analysisAsMap = (Map<String, Object>) asMap.get("analysis");
+        assertThat(analysisAsMap, hasEntry("name", "classification"));
+        assertThat(analysisAsMap, hasKey("parameters"));
+        Map<String, Object> paramsAsMap = (Map<String, Object>) analysisAsMap.get("parameters");
+        assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var"));
+        assertThat(paramsAsMap, hasEntry("prediction_field_type", "int"));
+        assertThat(paramsAsMap, hasEntry("num_classes", 8));
+    }
+
+    private AnalyticsProcessConfig createProcessConfig(DataFrameAnalysis analysis, ExtractedFields extractedFields) {
+        return new AnalyticsProcessConfig(jobId, rows, cols, memoryLimit, threads, resultsField, categoricalFields, analysis,
+            extractedFields);
+    }
+
+    private static Map<String, Object> toMap(AnalyticsProcessConfig config) throws IOException {
+        try (XContentBuilder builder = JsonXContent.contentBuilder()) {
+            config.toXContent(builder, ToXContent.EMPTY_PARAMS);
+            return XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(builder), false);
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    private void assertRandomizedFields(Map<String, Object> configAsMap) {
+        assertThat(configAsMap, hasEntry("job_id", jobId));
+        assertThat(configAsMap, hasEntry("rows", rows));
+        assertThat(configAsMap, hasEntry("cols", cols));
+        assertThat(configAsMap, hasEntry("memory_limit", memoryLimit.getBytes()));
+        assertThat(configAsMap, hasEntry("threads", threads));
+        assertThat(configAsMap, hasEntry("results_field", resultsField));
+        assertThat(configAsMap, hasKey("categorical_fields"));
+        assertThat((List<String>) configAsMap.get("categorical_fields"), containsInAnyOrder(categoricalFields.toArray()));
+    }
+}
@@ -32,7 +32,7 @@ public class ExtractedFieldsTests extends ESTestCase {
         ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
         ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
         ExtractedFields extractedFields = new ExtractedFields(Arrays.asList(
-            docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2));
+            docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap());
 
         assertThat(extractedFields.getAllFields().size(), equalTo(6));
         assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
@@ -54,7 +54,7 @@ public class ExtractedFieldsTests extends ESTestCase {
         when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps);
 
         ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"),
-            new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse);
+            new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap());
 
         assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
         assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
@@ -77,7 +77,7 @@ public class ExtractedFieldsTests extends ESTestCase {
         when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword);
 
         ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"),
-            Collections.emptySet(), fieldCapabilitiesResponse);
+            Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap());
 
         assertThat(extractedFields.getDocValueFields().size(), equalTo(1));
         assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword"));
@@ -119,7 +119,7 @@ public class ExtractedFieldsTests extends ESTestCase {
         FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
 
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build(
-            Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse));
+            Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()));
         assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
     }
 
@@ -142,7 +142,7 @@
         id: "start_given_empty_dest_index"
 
 ---
-"Test start classification analysis when the dependent variable cardinality is too low or too high":
+"Test start classification analysis when the dependent variable cardinality is too low":
   - do:
       indices.create:
         index: index-with-dep-var-with-too-high-card
@@ -179,22 +179,3 @@
       catch: /Field \[keyword_field\] must have at least \[2\] distinct values but there were \[1\]/
       ml.start_data_frame_analytics:
         id: "classification-cardinality-limits"
-
-  - do:
-      index:
-        index: index-with-dep-var-with-too-high-card
-        body: { numeric_field: 2.0, keyword_field: "class_b" }
-
-  - do:
-      index:
-        index: index-with-dep-var-with-too-high-card
-        body: { numeric_field: 3.0, keyword_field: "class_c" }
-
-  - do:
-      indices.refresh:
-        index: index-with-dep-var-with-too-high-card
-
-  - do:
-      catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
-      ml.start_data_frame_analytics:
-        id: "classification-cardinality-limits"