[7.x] Pass `prediction_field_type` to C++ analytics process (#49861) (#49981)

This commit is contained in:
Przemysław Witek 2019-12-09 14:43:01 +01:00 committed by GitHub
parent 049d854360
commit 0965a10468
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 313 additions and 127 deletions

View File

@ -67,6 +67,12 @@ public class Classification implements DataFrameAnalysis {
.flatMap(Set::stream)
.collect(Collectors.toSet()));
/**
* Name of the parameter passed down to C++.
* This parameter is used to decide which JSON data type from {string, int, bool} to use when writing the prediction.
*/
private static final String PREDICTION_FIELD_TYPE = "prediction_field_type";
/**
* As long as we only support binary classification it makes sense to always report both classes with their probabilities.
* This way the user can see if the prediction was made with confidence they need.
@ -154,7 +160,7 @@ public class Classification implements DataFrameAnalysis {
}
@Override
public Map<String, Object> getParams() {
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());
@ -162,9 +168,30 @@ public class Classification implements DataFrameAnalysis {
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
if (predictionFieldType != null) {
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
}
return params;
}
private static String getPredictionFieldType(Set<String> dependentVariableTypes) {
if (dependentVariableTypes == null) {
return null;
}
if (Types.categorical().containsAll(dependentVariableTypes)) {
return "string";
}
if (Types.bool().containsAll(dependentVariableTypes)) {
return "bool";
}
if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
// C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers.
return "int";
}
return null;
}
@Override
public boolean supportsCategoricalFields() {
return true;

View File

@ -16,8 +16,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
/**
* @return The analysis parameters as a map
* @param extractedFields map of (name, types) for all the extracted fields
*/
Map<String, Object> getParams();
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
/**
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)

View File

@ -192,7 +192,7 @@ public class OutlierDetection implements DataFrameAnalysis {
}
@Override
public Map<String, Object> getParams() {
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
Map<String, Object> params = new HashMap<>();
if (nNeighbors != null) {
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);

View File

@ -124,7 +124,7 @@ public class Regression implements DataFrameAnalysis {
}
@Override
public Map<String, Object> getParams() {
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());

View File

@ -8,11 +8,20 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.hamcrest.Matchers;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
@ -115,6 +124,34 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}
public void testGetParams() {
Map<String, Set<String>> extractedFields = new HashMap<>(3);
extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
assertThat(
new Classification("foo").getParams(extractedFields),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "foo"),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "foo_prediction"),
hasEntry("prediction_field_type", "bool")));
assertThat(
new Classification("bar").getParams(extractedFields),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "bar"),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "bar_prediction"),
hasEntry("prediction_field_type", "int")));
assertThat(
new Classification("baz").getParams(extractedFields),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "baz"),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "baz_prediction"),
hasEntry("prediction_field_type", "string")));
}
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}

View File

@ -51,7 +51,7 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
public void testGetParams_GivenDefaults() {
OutlierDetection outlierDetection = new OutlierDetection.Builder().build();
Map<String, Object> params = outlierDetection.getParams();
Map<String, Object> params = outlierDetection.getParams(null);
assertThat(params.size(), equalTo(3));
assertThat(params.containsKey("compute_feature_influence"), is(true));
assertThat(params.get("compute_feature_influence"), is(true));
@ -71,7 +71,7 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
.setStandardizationEnabled(false)
.build();
Map<String, Object> params = outlierDetection.getParams();
Map<String, Object> params = outlierDetection.getParams(null);
assertThat(params.size(), equalTo(6));
assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42));

View File

@ -12,7 +12,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
@ -83,6 +85,12 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
public void testGetParams() {
assertThat(
new Regression("foo").getParams(null),
allOf(hasEntry("dependent_variable", "foo"), hasEntry("prediction_field_name", "foo_prediction")));
}
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.empty;
public class TypesTests extends ESTestCase {
public void testTypes() {
assertThat(Sets.intersection(Types.bool(), Types.categorical()), empty());
assertThat(Sets.intersection(Types.categorical(), Types.numerical()), empty());
assertThat(Sets.intersection(Types.numerical(), Types.bool()), empty());
assertThat(Sets.difference(Types.discreteNumerical(), Types.numerical()), empty());
}
}

View File

@ -28,8 +28,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
private static final String ANIMAL_NAME_FIELD = "animal_name";
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
private static final String NO_LEGS_FIELD = "no_legs";
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
private static final String IS_PREDATOR_FIELD = "predator";
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
@Before
public void setup() {
@ -41,9 +45,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
cleanUp();
}
public void testEvaluate_MulticlassClassification_DefaultMetrics() {
public void testEvaluate_DefaultMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -52,10 +56,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
}
public void testEvaluate_MulticlassClassification_Accuracy() {
public void testEvaluate_Accuracy_KeywordField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy())));
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -74,11 +78,50 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
}
public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
public void testEvaluate_Accuracy_IntegerField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getActualClasses(),
equalTo(Arrays.asList(
new Accuracy.ActualClass("1", 15, 1.0 / 15),
new Accuracy.ActualClass("2", 15, 2.0 / 15),
new Accuracy.ActualClass("3", 15, 3.0 / 15),
new Accuracy.ActualClass("4", 15, 4.0 / 15),
new Accuracy.ActualClass("5", 15, 5.0 / 15))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
}
public void testEvaluate_Accuracy_BooleanField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getActualClasses(),
equalTo(Arrays.asList(
new Accuracy.ActualClass("true", 45, 27.0 / 45),
new Accuracy.ActualClass("false", 30, 18.0 / 30))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
}
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -137,11 +180,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -168,20 +211,30 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
private static void indexAnimalsData(String indexName) {
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
.addMapping("_doc",
ANIMAL_NAME_FIELD, "type=keyword",
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
NO_LEGS_FIELD, "type=integer",
NO_LEGS_PREDICTION_FIELD, "type=integer",
IS_PREDATOR_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
.get();
List<String> classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < classNames.size(); i++) {
for (int j = 0; j < classNames.size(); j++) {
for (int i = 0; i < animalNames.size(); i++) {
for (int j = 0; j < animalNames.size(); j++) {
for (int k = 0; k < j + 1; k++) {
bulkRequestBuilder.add(
new IndexRequest(indexName)
.source(
ACTUAL_CLASS_FIELD, classNames.get(i),
PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
ANIMAL_NAME_FIELD, animalNames.get(i),
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
NO_LEGS_FIELD, i + 1,
NO_LEGS_PREDICTION_FIELD, j + 1,
IS_PREDATOR_FIELD, i % 2 == 0,
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
}
}
}

View File

@ -20,9 +20,8 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.junit.After;
import java.util.ArrayList;
@ -30,7 +29,6 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.allOf;
@ -88,7 +86,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
}
assertProgress(jobId, 100, 100, 100, 100);
@ -102,7 +100,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword");
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
@ -128,7 +126,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
}
assertProgress(jobId, 100, 100, 100, 100);
@ -142,11 +140,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword");
}
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
String jobId, String dependentVariable, List<T> dependentVariableValues, Function<String, T> parser) throws Exception {
String jobId, String dependentVariable, List<T> dependentVariableValues) throws Exception {
initialize(jobId);
String predictedClassField = dependentVariable + "_prediction";
indexData(sourceIndex, 300, 0, dependentVariable);
@ -175,9 +173,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey(predictedClassField), is(true));
T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField));
@SuppressWarnings("unchecked")
T predictedClassValue = (T) resultsObject.get(predictedClassField);
assertThat(predictedClassValue, is(in(dependentVariableValues)));
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser);
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues);
assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results
@ -201,33 +200,32 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis");
assertEvaluation(
dependentVariable,
dependentVariableValues.stream().map(String::valueOf).collect(toList()),
"ml." + predictedClassField);
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword");
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf);
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES);
assertEvaluation(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "ml.discrete-numerical-field_prediction");
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
ElasticsearchStatusException e = expectThrows(
ElasticsearchStatusException.class,
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf));
"classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES));
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf);
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES);
assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction");
}
public void testDependentVariableCardinalityTooHighError() {
@ -317,25 +315,24 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
return resultsObject;
}
@SuppressWarnings("unchecked")
private static <T> void assertTopClasses(
Map<String, Object> resultsObject,
int numTopClasses,
String dependentVariable,
List<T> dependentVariableValues,
Function<String, T> parser) {
List<T> dependentVariableValues) {
assertThat(resultsObject.containsKey("top_classes"), is(true));
@SuppressWarnings("unchecked")
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
assertThat(topClasses, hasSize(numTopClasses));
List<String> classNames = new ArrayList<>(topClasses.size());
List<T> classNames = new ArrayList<>(topClasses.size());
List<Double> classProbabilities = new ArrayList<>(topClasses.size());
for (Map<String, Object> topClass : topClasses) {
assertThat(topClass, allOf(hasKey("class_name"), hasKey("class_probability")));
classNames.add((String) topClass.get("class_name"));
classNames.add((T) topClass.get("class_name"));
classProbabilities.add((Double) topClass.get("class_probability"));
}
// Assert that all the predicted class names come from the set of dependent variable values.
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
classNames.forEach(className -> assertThat(className, is(in(dependentVariableValues))));
// Assert that the first class listed in top classes is the same as the predicted class.
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
// Assert that all the class probabilities lie within [0, 1] interval.
@ -344,25 +341,44 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true));
}
private void assertEvaluation(String dependentVariable, List<String> dependentVariableValues, String predictedClassField) {
private <T> void assertEvaluation(String dependentVariable, List<T> dependentVariableValues, String predictedClassField) {
List<String> dependentVariableValuesAsStrings = dependentVariableValues.stream().map(String::valueOf).collect(toList());
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
destIndex,
new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
dependentVariable, predictedClassField, null));
dependentVariable, predictedClassField, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
{ // Accuracy
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
List<Accuracy.ActualClass> actualClasses = accuracyResult.getActualClasses();
assertThat(
actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()),
equalTo(dependentVariableValuesAsStrings));
actualClasses.forEach(
actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
}
{ // MulticlassConfusionMatrix
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
List<ActualClass> actualClasses = confusionMatrixResult.getConfusionMatrix();
assertThat(actualClasses.stream().map(ActualClass::getActualClass).collect(toList()), equalTo(dependentVariableValues));
for (ActualClass actualClass : actualClasses) {
List<MulticlassConfusionMatrix.ActualClass> actualClasses = confusionMatrixResult.getConfusionMatrix();
assertThat(
actualClasses.stream().map(MulticlassConfusionMatrix.ActualClass::getActualClass).collect(toList()),
equalTo(dependentVariableValuesAsStrings));
for (MulticlassConfusionMatrix.ActualClass actualClass : actualClasses) {
assertThat(actualClass.getOtherPredictedClassDocCount(), equalTo(0L));
assertThat(
actualClass.getPredictedClasses().stream().map(PredictedClass::getPredictedClass).collect(toList()),
equalTo(dependentVariableValues));
actualClass.getPredictedClasses().stream()
.map(MulticlassConfusionMatrix.PredictedClass::getPredictedClass)
.collect(toList()),
equalTo(dependentVariableValuesAsStrings));
}
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
}
}

View File

@ -56,6 +56,10 @@ public class DataFrameDataExtractorFactory {
return new DataFrameDataExtractor(client, context);
}
public ExtractedFields getExtractedFields() {
return extractedFields;
}
private QueryBuilder createQuery() {
BoolQueryBuilder query = QueryBuilders.boolQuery();
query.filter(sourceQuery);

View File

@ -379,14 +379,12 @@ public class ExtractedFieldsDetector {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getAllFields()) {
if (isBoolean(field.getTypes())) {
if (config.getAnalysis().getAllowedCategoricalTypes(field.getName()).contains(BooleanFieldMapper.CONTENT_TYPE)) {
// We convert boolean field to string if it is a categorical dependent variable
adjusted.add(ExtractedFields.applyBooleanMapping(field, Boolean.TRUE.toString(), Boolean.FALSE.toString()));
} else {
// We convert boolean fields to integers with values 0, 1 as this is the preferred
// way to consume such features in the analytics process.
adjusted.add(ExtractedFields.applyBooleanMapping(field, 1, 0));
}
// way to consume such features in the analytics process regardless of:
// - analysis type
// - whether or not the field is categorical
// - whether or not the field is a dependent variable
adjusted.add(ExtractedFields.applyBooleanMapping(field));
} else {
adjusted.add(field);
}

View File

@ -9,11 +9,15 @@ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import java.io.IOException;
import java.util.Objects;
import java.util.Set;
import static java.util.stream.Collectors.toMap;
public class AnalyticsProcessConfig implements ToXContentObject {
private static final String JOB_ID = "job_id";
@ -33,9 +37,10 @@ public class AnalyticsProcessConfig implements ToXContentObject {
private final String resultsField;
private final Set<String> categoricalFields;
private final DataFrameAnalysis analysis;
private final ExtractedFields extractedFields;
public AnalyticsProcessConfig(String jobId, long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField,
Set<String> categoricalFields, DataFrameAnalysis analysis) {
Set<String> categoricalFields, DataFrameAnalysis analysis, ExtractedFields extractedFields) {
this.jobId = Objects.requireNonNull(jobId);
this.rows = rows;
this.cols = cols;
@ -44,6 +49,7 @@ public class AnalyticsProcessConfig implements ToXContentObject {
this.resultsField = Objects.requireNonNull(resultsField);
this.categoricalFields = Objects.requireNonNull(categoricalFields);
this.analysis = Objects.requireNonNull(analysis);
this.extractedFields = Objects.requireNonNull(extractedFields);
}
public String jobId() {
@ -68,7 +74,7 @@ public class AnalyticsProcessConfig implements ToXContentObject {
builder.field(THREADS, threads);
builder.field(RESULTS_FIELD, resultsField);
builder.field(CATEGORICAL_FIELDS, categoricalFields);
builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis));
builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis, extractedFields));
builder.endObject();
return builder;
}
@ -76,16 +82,21 @@ public class AnalyticsProcessConfig implements ToXContentObject {
private static class DataFrameAnalysisWrapper implements ToXContentObject {
private final DataFrameAnalysis analysis;
private final ExtractedFields extractedFields;
private DataFrameAnalysisWrapper(DataFrameAnalysis analysis) {
private DataFrameAnalysisWrapper(DataFrameAnalysis analysis, ExtractedFields extractedFields) {
this.analysis = analysis;
this.extractedFields = extractedFields;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("name", analysis.getWriteableName());
builder.field("parameters", analysis.getParams());
builder.field(
"parameters",
analysis.getParams(
extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes))));
builder.endObject();
return builder;
}

View File

@ -33,6 +33,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@ -373,7 +374,8 @@ public class AnalyticsProcessManager {
}
dataExtractor = dataExtractorFactory.newExtractor(false);
AnalyticsProcessConfig analyticsProcessConfig = createProcessConfig(config, dataExtractor);
AnalyticsProcessConfig analyticsProcessConfig =
createProcessConfig(config, dataExtractor, dataExtractorFactory.getExtractedFields());
LOGGER.trace("[{}] creating analytics process with config [{}]", config.getId(), Strings.toString(analyticsProcessConfig));
// If we have no rows, that means there is no data so no point in starting the native process
// just finish the task
@ -389,11 +391,20 @@ public class AnalyticsProcessManager {
return true;
}
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) {
private AnalyticsProcessConfig createProcessConfig(
DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor, ExtractedFields extractedFields) {
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
Set<String> categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis());
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(config.getId(), dataSummary.rows, dataSummary.cols,
config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis());
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(
config.getId(),
dataSummary.rows,
dataSummary.cols,
config.getModelMemoryLimit(),
1,
config.getDest().getResultsField(),
categoricalFields,
config.getAnalysis(),
extractedFields);
return processConfig;
}
}

View File

@ -79,7 +79,8 @@ public class MemoryUsageEstimationProcessManager {
1,
"",
categoricalFields,
config.getAnalysis());
config.getAnalysis(),
dataExtractorFactory.getExtractedFields());
AnalyticsProcess<MemoryUsageEstimationResult> process =
processFactory.createAnalyticsProcess(
config,

View File

@ -62,8 +62,8 @@ public class ExtractedFields {
return new TimeField(name, method);
}
public static <T> ExtractedField applyBooleanMapping(ExtractedField field, T trueValue, T falseValue) {
return new BooleanMapper<>(field, trueValue, falseValue);
public static ExtractedField applyBooleanMapping(ExtractedField field) {
return new BooleanMapper<>(field, 1, 0);
}
public static class ExtractionMethodDetector {

View File

@ -36,6 +36,7 @@ import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -57,7 +58,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(1));
assertThat(allFields, hasSize(1));
assertThat(allFields.get(0).getName(), equalTo("some_float"));
assertThat(allFields.get(0).getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
@ -75,7 +76,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(1));
assertThat(allFields, hasSize(1));
assertThat(allFields.get(0).getName(), equalTo("some_number"));
assertThat(allFields.get(0).getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
@ -121,7 +122,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(3));
assertThat(allFields, hasSize(3));
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toSet()),
containsInAnyOrder("some_float", "some_long", "some_boolean"));
assertThat(allFields.stream().map(ExtractedField::getMethod).collect(Collectors.toSet()),
@ -150,7 +151,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(5));
assertThat(allFields, hasSize(5));
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()),
containsInAnyOrder("foo", "some_float", "some_keyword", "some_long", "some_boolean"));
assertThat(allFields.stream().map(ExtractedField::getMethod).collect(Collectors.toSet()),
@ -223,7 +224,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(1));
assertThat(allFields, hasSize(1));
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()), contains("bar"));
assertFieldSelectionContains(fieldExtraction.v2(),
@ -329,7 +330,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(1));
assertThat(allFields, hasSize(1));
assertThat(allFields.get(0).getName(), equalTo("numeric"));
assertFieldSelectionContains(fieldExtraction.v2(),
@ -565,23 +566,24 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
contains(equalTo(ExtractedField.Method.SOURCE)));
}
public void testDetect_GivenBooleanField_BooleanMappedAsInteger() {
private void testDetect_GivenBooleanField(DataFrameAnalyticsConfig config, boolean isRequired, FieldSelection.FeatureType featureType) {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_boolean", "boolean")
.addAggregatableField("some_integer", "integer")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap());
SOURCE_INDEX, config, false, 100, fieldCapabilities, config.getAnalysis().getFieldCardinalityLimits());
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(1));
assertThat(allFields, hasSize(2));
ExtractedField booleanField = allFields.get(0);
assertThat(booleanField.getTypes(), contains("boolean"));
assertThat(booleanField.getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
assertFieldSelectionContains(fieldExtraction.v2(),
FieldSelection.included("some_boolean", Collections.singleton("boolean"), false, FieldSelection.FeatureType.NUMERICAL)
assertFieldSelectionContains(fieldExtraction.v2().subList(0, 1),
FieldSelection.included("some_boolean", Collections.singleton("boolean"), isRequired, featureType)
);
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
@ -594,34 +596,24 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
assertThat(booleanField.value(hit), arrayContaining(0, 1, 0));
}
public void testDetect_GivenBooleanField_BooleanMappedAsString() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("some_boolean", "boolean")
.build();
public void testDetect_GivenBooleanField_OutlierDetection() {
// some_boolean is a non-required, numerical feature in outlier detection analysis
testDetect_GivenBooleanField(buildOutlierDetectionConfig(), false, FieldSelection.FeatureType.NUMERICAL);
}
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildClassificationConfig("some_boolean"), false, 100, fieldCapabilities,
Collections.singletonMap("some_boolean", 2L));
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
public void testDetect_GivenBooleanField_Regression() {
// some_boolean is a non-required, numerical feature in regression analysis
testDetect_GivenBooleanField(buildRegressionConfig("some_integer"), false, FieldSelection.FeatureType.NUMERICAL);
}
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(1));
ExtractedField booleanField = allFields.get(0);
assertThat(booleanField.getTypes(), contains("boolean"));
assertThat(booleanField.getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
public void testDetect_GivenBooleanField_Classification_BooleanIsFeature() {
// some_boolean is a non-required, numerical feature in classification analysis
testDetect_GivenBooleanField(buildClassificationConfig("some_integer"), false, FieldSelection.FeatureType.NUMERICAL);
}
assertFieldSelectionContains(fieldExtraction.v2(),
FieldSelection.included("some_boolean", Collections.singleton("boolean"), true, FieldSelection.FeatureType.CATEGORICAL)
);
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
assertThat(booleanField.value(hit), arrayContaining("true"));
hit = new SearchHitBuilder(42).addField("some_boolean", false).build();
assertThat(booleanField.value(hit), arrayContaining("false"));
hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build();
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
public void testDetect_GivenBooleanField_Classification_BooleanIsDependentVariable() {
// some_boolean is a required, categorical dependent variable in classification analysis
testDetect_GivenBooleanField(buildClassificationConfig("some_boolean"), true, FieldSelection.FeatureType.CATEGORICAL);
}
public void testDetect_GivenMultiFields() {
@ -640,7 +632,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
SOURCE_INDEX, buildRegressionConfig("a_float"), true, 100, fieldCapabilities, Collections.emptyMap());
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(5));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(5));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword"));
@ -671,7 +663,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
SOURCE_INDEX, buildClassificationConfig("field_1"), true, 100, fieldCapabilities, Collections.singletonMap("field_1", 2L));
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
@ -696,7 +688,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Collections.singletonMap("field_1.keyword", 2L));
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1.keyword", "field_2"));
@ -722,7 +714,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap());
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1.keyword_1", "field_2"));
@ -748,7 +740,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
SOURCE_INDEX, buildRegressionConfig("field_2"), true, 0, fieldCapabilities, Collections.emptyMap());
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
@ -773,7 +765,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
SOURCE_INDEX, buildRegressionConfig("field_2.double"), true, 100, fieldCapabilities, Collections.emptyMap());
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2.double"));
@ -798,7 +790,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap());
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
@ -823,7 +815,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
SOURCE_INDEX, buildRegressionConfig("field_2"), false, 100, fieldCapabilities, Collections.emptyMap());
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
@ -849,7 +841,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(2));
assertThat(allFields, hasSize(2));
assertThat(allFields.get(0).getName(), equalTo("field_11"));
assertThat(allFields.get(1).getName(), equalTo("field_12"));
@ -872,7 +864,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
assertThat(allFields.size(), equalTo(2));
assertThat(allFields, hasSize(2));
assertThat(allFields.get(0).getName(), equalTo("field_21"));
assertThat(allFields.get(1).getName(), equalTo("field_22"));
@ -914,7 +906,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
* We assert each field individually to get useful error messages in case of failure
*/
private static void assertFieldSelectionContains(List<FieldSelection> actual, FieldSelection... expected) {
assertThat(actual.size(), equalTo(expected.length));
assertThat(actual, hasSize(expected.length));
for (int i = 0; i < expected.length; i++) {
assertThat("i = " + i, actual.get(i).getName(), equalTo(expected[i].getName()));
assertThat("i = " + i, actual.get(i).getMappingTypes(), equalTo(expected[i].getMappingTypes()));

View File

@ -18,6 +18,7 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
@ -95,6 +96,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
finishHandler = mock(Consumer.class);
exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
@ -219,7 +220,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
private void givenDataFrameRows(int rows) {
AnalyticsProcessConfig config = new AnalyticsProcessConfig(
"job_id", rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class));
"job_id", rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class),
mock(ExtractedFields.class));
when(process.getConfig()).thenReturn(config);
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
@ -70,6 +71,7 @@ public class MemoryUsageEstimationProcessManagerTests extends ESTestCase {
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID);
listener = mock(ActionListener.class);
resultCaptor = ArgumentCaptor.forClass(MemoryUsageEstimationResult.class);

View File

@ -101,7 +101,7 @@ public class ExtractedFieldsTests extends ESTestCase {
public void testApplyBooleanMapping() {
DocValueField aBool = new DocValueField("a_bool", Collections.singleton("boolean"));
ExtractedField mapped = ExtractedFields.applyBooleanMapping(aBool, 1, 0);
ExtractedField mapped = ExtractedFields.applyBooleanMapping(aBool);
SearchHit hitTrue = new SearchHitBuilder(42).addField("a_bool", true).build();
SearchHit hitFalse = new SearchHitBuilder(42).addField("a_bool", false).build();