diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 571a619ac13..3b4c8c67f63 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -67,6 +67,12 @@ public class Classification implements DataFrameAnalysis { .flatMap(Set::stream) .collect(Collectors.toSet())); + /** + * Name of the parameter passed down to C++. + * This parameter is used to decide which JSON data type from {string, int, bool} to use when writing the prediction. + */ + private static final String PREDICTION_FIELD_TYPE = "prediction_field_type"; + /** * As long as we only support binary classification it makes sense to always report both classes with their probabilities. * This way the user can see if the prediction was made with confidence they need. @@ -154,7 +160,7 @@ public class Classification implements DataFrameAnalysis { } @Override - public Map getParams() { + public Map getParams(Map> extractedFields) { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); @@ -162,9 +168,30 @@ public class Classification implements DataFrameAnalysis { if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } + String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable)); + if (predictionFieldType != null) { + params.put(PREDICTION_FIELD_TYPE, predictionFieldType); + } return params; } + private static String getPredictionFieldType(Set dependentVariableTypes) { + if (dependentVariableTypes == null) { + return null; + } + if (Types.categorical().containsAll(dependentVariableTypes)) { + return "string"; + } + if (Types.bool().containsAll(dependentVariableTypes)) { + return "bool"; + } + if (Types.discreteNumerical().containsAll(dependentVariableTypes)) { + // C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers. + return "int"; + } + return null; + } + @Override public boolean supportsCategoricalFields() { return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index 0ca32cde402..d0af0a452a4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -16,8 +16,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { /** * @return The analysis parameters as a map + * @param extractedFields map of (name, types) for all the extracted fields */ - Map getParams(); + Map getParams(Map> extractedFields); /** * @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index d4cefe884b5..70b3cfb9fe2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -192,7 +192,7 @@ public class OutlierDetection implements DataFrameAnalysis { } @Override - public Map getParams() { + public Map getParams(Map> extractedFields) { Map params = new HashMap<>(); if (nNeighbors != null) { params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 6fa163dd65c..01388f01d80 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -124,7 +124,7 @@ public class Regression implements DataFrameAnalysis { } @Override - public Map getParams() { + public Map getParams(Map> extractedFields) { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 8306d08af79..a0c4f31e744 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -8,11 +8,20 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.BooleanFieldMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.hamcrest.Matchers; import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -115,6 +124,34 @@ public class ClassificationTests extends AbstractSerializingTestCase> extractedFields = new HashMap<>(3); + extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE)); + extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName())); + extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE)); + assertThat( + new Classification("foo").getParams(extractedFields), + Matchers.>allOf( + hasEntry("dependent_variable", "foo"), + hasEntry("num_top_classes", 2), + hasEntry("prediction_field_name", "foo_prediction"), + hasEntry("prediction_field_type", "bool"))); + assertThat( + new Classification("bar").getParams(extractedFields), + Matchers.>allOf( + hasEntry("dependent_variable", "bar"), + hasEntry("num_top_classes", 2), + hasEntry("prediction_field_name", "bar_prediction"), + hasEntry("prediction_field_type", "int"))); + assertThat( + new Classification("baz").getParams(extractedFields), + Matchers.>allOf( + hasEntry("dependent_variable", "baz"), + hasEntry("num_top_classes", 2), + hasEntry("prediction_field_name", "baz_prediction"), + hasEntry("prediction_field_type", "string"))); + } + public void testFieldCardinalityLimitsIsNonNull() { assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java index b1181714517..c35b9a3bad1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java @@ -51,7 +51,7 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase params = outlierDetection.getParams(); + Map params = outlierDetection.getParams(null); assertThat(params.size(), equalTo(3)); assertThat(params.containsKey("compute_feature_influence"), is(true)); assertThat(params.get("compute_feature_influence"), is(true)); @@ -71,7 +71,7 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase params = outlierDetection.getParams(); + Map params = outlierDetection.getParams(null); assertThat(params.size(), equalTo(6)); assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index 089f29e53cb..01c030d2c84 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -12,7 +12,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -83,6 +85,12 @@ public class RegressionTests extends AbstractSerializingTestCase { assertThat(regression.getTrainingPercent(), equalTo(100.0)); } + public void testGetParams() { + assertThat( + new Regression("foo").getParams(null), + allOf(hasEntry("dependent_variable", "foo"), hasEntry("prediction_field_name", "foo_prediction"))); + } + public void testFieldCardinalityLimitsIsNonNull() { assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/TypesTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/TypesTests.java new file mode 100644 index 00000000000..beac3bcae94 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/TypesTests.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.empty; + +public class TypesTests extends ESTestCase { + + public void testTypes() { + assertThat(Sets.intersection(Types.bool(), Types.categorical()), empty()); + assertThat(Sets.intersection(Types.categorical(), Types.numerical()), empty()); + assertThat(Sets.intersection(Types.numerical(), Types.bool()), empty()); + assertThat(Sets.difference(Types.discreteNumerical(), Types.numerical()), empty()); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index c876174d290..9f9db808440 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -28,8 +28,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; - private static final String ACTUAL_CLASS_FIELD = "actual_class_field"; - private static final String PREDICTED_CLASS_FIELD = "predicted_class_field"; + private static final String ANIMAL_NAME_FIELD = "animal_name"; + private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction"; + private static final String NO_LEGS_FIELD = "no_legs"; + private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction"; + private static final String IS_PREDATOR_FIELD = "predator"; + private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction"; @Before public void setup() { @@ -41,9 +45,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT cleanUp(); } - public void testEvaluate_MulticlassClassification_DefaultMetrics() { + public void testEvaluate_DefaultMetrics() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null)); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -52,10 +56,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); } - public void testEvaluate_MulticlassClassification_Accuracy() { + public void testEvaluate_Accuracy_KeywordField() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy()))); + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -74,11 +78,50 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); } - public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() { + public void testEvaluate_Accuracy_IntegerField() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat( + accuracyResult.getActualClasses(), + equalTo(Arrays.asList( + new Accuracy.ActualClass("1", 15, 1.0 / 15), + new Accuracy.ActualClass("2", 15, 2.0 / 15), + new Accuracy.ActualClass("3", 15, 3.0 / 15), + new Accuracy.ActualClass("4", 15, 4.0 / 15), + new Accuracy.ActualClass("5", 15, 5.0 / 15)))); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); + } + + public void testEvaluate_Accuracy_BooleanField() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat( + accuracyResult.getActualClasses(), + equalTo(Arrays.asList( + new Accuracy.ActualClass("true", 45, 27.0 / 45), + new Accuracy.ActualClass("false", 30, 18.0 / 30)))); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); + } + + public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); + new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -137,11 +180,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } - public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() { + public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3)))); + new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -168,20 +211,30 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT private static void indexAnimalsData(String indexName) { client().admin().indices().prepareCreate(indexName) - .addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword") + .addMapping("_doc", + ANIMAL_NAME_FIELD, "type=keyword", + ANIMAL_NAME_PREDICTION_FIELD, "type=keyword", + NO_LEGS_FIELD, "type=integer", + NO_LEGS_PREDICTION_FIELD, "type=integer", + IS_PREDATOR_FIELD, "type=boolean", + IS_PREDATOR_PREDICTION_FIELD, "type=boolean") .get(); - List classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox"); + List animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox"); BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - for (int i = 0; i < classNames.size(); i++) { - for (int j = 0; j < classNames.size(); j++) { + for (int i = 0; i < animalNames.size(); i++) { + for (int j = 0; j < animalNames.size(); j++) { for (int k = 0; k < j + 1; k++) { bulkRequestBuilder.add( new IndexRequest(indexName) .source( - ACTUAL_CLASS_FIELD, classNames.get(i), - PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size()))); + ANIMAL_NAME_FIELD, animalNames.get(i), + ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()), + NO_LEGS_FIELD, i + 1, + NO_LEGS_PREDICTION_FIELD, j + 1, + IS_PREDATOR_FIELD, i % 2 == 0, + IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0)); } } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index a7721f667b7..b86c5ac7a6c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -20,9 +20,8 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.junit.After; import java.util.ArrayList; @@ -30,7 +29,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.function.Function; import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.allOf; @@ -88,7 +86,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); - assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); + assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); } assertProgress(jobId, 100, 100, 100, 100); @@ -102,7 +100,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { "Creating destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { @@ -128,7 +126,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(true)); - assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); + assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); } assertProgress(jobId, 100, 100, 100, 100); @@ -142,11 +140,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { "Creating destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - String jobId, String dependentVariable, List dependentVariableValues, Function parser) throws Exception { + String jobId, String dependentVariable, List dependentVariableValues) throws Exception { initialize(jobId); String predictedClassField = dependentVariable + "_prediction"; indexData(sourceIndex, 300, 0, dependentVariable); @@ -175,9 +173,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); assertThat(resultsObject.containsKey(predictedClassField), is(true)); - T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField)); + @SuppressWarnings("unchecked") + T predictedClassValue = (T) resultsObject.get(predictedClassField); assertThat(predictedClassValue, is(in(dependentVariableValues))); - assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser); + assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues); assertThat(resultsObject.containsKey("is_training"), is(true)); // Let's just assert there's both training and non-training results @@ -201,33 +200,32 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { "Creating destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertEvaluation( - dependentVariable, - dependentVariableValues.stream().map(String::valueOf).collect(toList()), - "ml." + predictedClassField); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception { testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); + "classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception { testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf); + "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES); + assertEvaluation(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "ml.discrete-numerical-field_prediction"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception { ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf)); + "classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES)); assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];")); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception { testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf); + "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES); + assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction"); } public void testDependentVariableCardinalityTooHighError() { @@ -317,25 +315,24 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { return resultsObject; } + @SuppressWarnings("unchecked") private static void assertTopClasses( Map resultsObject, int numTopClasses, String dependentVariable, - List dependentVariableValues, - Function parser) { + List dependentVariableValues) { assertThat(resultsObject.containsKey("top_classes"), is(true)); - @SuppressWarnings("unchecked") List> topClasses = (List>) resultsObject.get("top_classes"); assertThat(topClasses, hasSize(numTopClasses)); - List classNames = new ArrayList<>(topClasses.size()); + List classNames = new ArrayList<>(topClasses.size()); List classProbabilities = new ArrayList<>(topClasses.size()); for (Map topClass : topClasses) { assertThat(topClass, allOf(hasKey("class_name"), hasKey("class_probability"))); - classNames.add((String) topClass.get("class_name")); + classNames.add((T) topClass.get("class_name")); classProbabilities.add((Double) topClass.get("class_probability")); } // Assert that all the predicted class names come from the set of dependent variable values. - classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues)))); + classNames.forEach(className -> assertThat(className, is(in(dependentVariableValues)))); // Assert that the first class listed in top classes is the same as the predicted class. assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction"))); // Assert that all the class probabilities lie within [0, 1] interval. @@ -344,25 +341,44 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true)); } - private void assertEvaluation(String dependentVariable, List dependentVariableValues, String predictedClassField) { + private void assertEvaluation(String dependentVariable, List dependentVariableValues, String predictedClassField) { + List dependentVariableValuesAsStrings = dependentVariableValues.stream().map(String::valueOf).collect(toList()); EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( destIndex, new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification( - dependentVariable, predictedClassField, null)); + dependentVariable, predictedClassField, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); - MulticlassConfusionMatrix.Result confusionMatrixResult = - (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); - List actualClasses = confusionMatrixResult.getConfusionMatrix(); - assertThat(actualClasses.stream().map(ActualClass::getActualClass).collect(toList()), equalTo(dependentVariableValues)); - for (ActualClass actualClass : actualClasses) { - assertThat(actualClass.getOtherPredictedClassDocCount(), equalTo(0L)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2)); + + { // Accuracy + Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + List actualClasses = accuracyResult.getActualClasses(); assertThat( - actualClass.getPredictedClasses().stream().map(PredictedClass::getPredictedClass).collect(toList()), - equalTo(dependentVariableValues)); + actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()), + equalTo(dependentVariableValuesAsStrings)); + actualClasses.forEach( + actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)))); + } + + { // MulticlassConfusionMatrix + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1); + assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + List actualClasses = confusionMatrixResult.getConfusionMatrix(); + assertThat( + actualClasses.stream().map(MulticlassConfusionMatrix.ActualClass::getActualClass).collect(toList()), + equalTo(dependentVariableValuesAsStrings)); + for (MulticlassConfusionMatrix.ActualClass actualClass : actualClasses) { + assertThat(actualClass.getOtherPredictedClassDocCount(), equalTo(0L)); + assertThat( + actualClass.getPredictedClasses().stream() + .map(MulticlassConfusionMatrix.PredictedClass::getPredictedClass) + .collect(toList()), + equalTo(dependentVariableValuesAsStrings)); + } + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } - assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 1c060f17864..c7d27805c3b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -56,6 +56,10 @@ public class DataFrameDataExtractorFactory { return new DataFrameDataExtractor(client, context); } + public ExtractedFields getExtractedFields() { + return extractedFields; + } + private QueryBuilder createQuery() { BoolQueryBuilder query = QueryBuilders.boolQuery(); query.filter(sourceQuery); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index 71058e1933e..82bbecbd358 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -379,14 +379,12 @@ public class ExtractedFieldsDetector { List adjusted = new ArrayList<>(extractedFields.getAllFields().size()); for (ExtractedField field : extractedFields.getAllFields()) { if (isBoolean(field.getTypes())) { - if (config.getAnalysis().getAllowedCategoricalTypes(field.getName()).contains(BooleanFieldMapper.CONTENT_TYPE)) { - // We convert boolean field to string if it is a categorical dependent variable - adjusted.add(ExtractedFields.applyBooleanMapping(field, Boolean.TRUE.toString(), Boolean.FALSE.toString())); - } else { - // We convert boolean fields to integers with values 0, 1 as this is the preferred - // way to consume such features in the analytics process. - adjusted.add(ExtractedFields.applyBooleanMapping(field, 1, 0)); - } + // We convert boolean fields to integers with values 0, 1 as this is the preferred + // way to consume such features in the analytics process regardless of: + // - analysis type + // - whether or not the field is categorical + // - whether or not the field is a dependent variable + adjusted.add(ExtractedFields.applyBooleanMapping(field)); } else { adjusted.add(field); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 9a172d158e5..714f6309180 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -9,11 +9,15 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.io.IOException; import java.util.Objects; import java.util.Set; +import static java.util.stream.Collectors.toMap; + public class AnalyticsProcessConfig implements ToXContentObject { private static final String JOB_ID = "job_id"; @@ -33,9 +37,10 @@ public class AnalyticsProcessConfig implements ToXContentObject { private final String resultsField; private final Set categoricalFields; private final DataFrameAnalysis analysis; + private final ExtractedFields extractedFields; public AnalyticsProcessConfig(String jobId, long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField, - Set categoricalFields, DataFrameAnalysis analysis) { + Set categoricalFields, DataFrameAnalysis analysis, ExtractedFields extractedFields) { this.jobId = Objects.requireNonNull(jobId); this.rows = rows; this.cols = cols; @@ -44,6 +49,7 @@ public class AnalyticsProcessConfig implements ToXContentObject { this.resultsField = Objects.requireNonNull(resultsField); this.categoricalFields = Objects.requireNonNull(categoricalFields); this.analysis = Objects.requireNonNull(analysis); + this.extractedFields = Objects.requireNonNull(extractedFields); } public String jobId() { @@ -68,7 +74,7 @@ public class AnalyticsProcessConfig implements ToXContentObject { builder.field(THREADS, threads); builder.field(RESULTS_FIELD, resultsField); builder.field(CATEGORICAL_FIELDS, categoricalFields); - builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis)); + builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis, extractedFields)); builder.endObject(); return builder; } @@ -76,16 +82,21 @@ public class AnalyticsProcessConfig implements ToXContentObject { private static class DataFrameAnalysisWrapper implements ToXContentObject { private final DataFrameAnalysis analysis; + private final ExtractedFields extractedFields; - private DataFrameAnalysisWrapper(DataFrameAnalysis analysis) { + private DataFrameAnalysisWrapper(DataFrameAnalysis analysis, ExtractedFields extractedFields) { this.analysis = analysis; + this.extractedFields = extractedFields; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("name", analysis.getWriteableName()); - builder.field("parameters", analysis.getParams()); + builder.field( + "parameters", + analysis.getParams( + extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes)))); builder.endObject(); return builder; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index ed9d715b5f7..815d8478a52 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor; import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -373,7 +374,8 @@ public class AnalyticsProcessManager { } dataExtractor = dataExtractorFactory.newExtractor(false); - AnalyticsProcessConfig analyticsProcessConfig = createProcessConfig(config, dataExtractor); + AnalyticsProcessConfig analyticsProcessConfig = + createProcessConfig(config, dataExtractor, dataExtractorFactory.getExtractedFields()); LOGGER.trace("[{}] creating analytics process with config [{}]", config.getId(), Strings.toString(analyticsProcessConfig)); // If we have no rows, that means there is no data so no point in starting the native process // just finish the task @@ -389,11 +391,20 @@ public class AnalyticsProcessManager { return true; } - private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { + private AnalyticsProcessConfig createProcessConfig( + DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor, ExtractedFields extractedFields) { DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); Set categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis()); - AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(config.getId(), dataSummary.rows, dataSummary.cols, - config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis()); + AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig( + config.getId(), + dataSummary.rows, + dataSummary.cols, + config.getModelMemoryLimit(), + 1, + config.getDest().getResultsField(), + categoricalFields, + config.getAnalysis(), + extractedFields); return processConfig; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java index 6740f8d4d34..a6223ec8b88 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java @@ -79,7 +79,8 @@ public class MemoryUsageEstimationProcessManager { 1, "", categoricalFields, - config.getAnalysis()); + config.getAnalysis(), + dataExtractorFactory.getExtractedFields()); AnalyticsProcess process = processFactory.createAnalyticsProcess( config, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java index 9fe079b745c..347d353664d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java @@ -62,8 +62,8 @@ public class ExtractedFields { return new TimeField(name, method); } - public static ExtractedField applyBooleanMapping(ExtractedField field, T trueValue, T falseValue) { - return new BooleanMapper<>(field, trueValue, falseValue); + public static ExtractedField applyBooleanMapping(ExtractedField field) { + return new BooleanMapper<>(field, 1, 0); } public static class ExtractionMethodDetector { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index f4f25bcfa06..9c55b2a9ac9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -36,6 +36,7 @@ import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -57,7 +58,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(1)); + assertThat(allFields, hasSize(1)); assertThat(allFields.get(0).getName(), equalTo("some_float")); assertThat(allFields.get(0).getMethod(), equalTo(ExtractedField.Method.DOC_VALUE)); @@ -75,7 +76,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(1)); + assertThat(allFields, hasSize(1)); assertThat(allFields.get(0).getName(), equalTo("some_number")); assertThat(allFields.get(0).getMethod(), equalTo(ExtractedField.Method.DOC_VALUE)); @@ -121,7 +122,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(3)); + assertThat(allFields, hasSize(3)); assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toSet()), containsInAnyOrder("some_float", "some_long", "some_boolean")); assertThat(allFields.stream().map(ExtractedField::getMethod).collect(Collectors.toSet()), @@ -150,7 +151,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(5)); + assertThat(allFields, hasSize(5)); assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()), containsInAnyOrder("foo", "some_float", "some_keyword", "some_long", "some_boolean")); assertThat(allFields.stream().map(ExtractedField::getMethod).collect(Collectors.toSet()), @@ -223,7 +224,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(1)); + assertThat(allFields, hasSize(1)); assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()), contains("bar")); assertFieldSelectionContains(fieldExtraction.v2(), @@ -329,7 +330,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(1)); + assertThat(allFields, hasSize(1)); assertThat(allFields.get(0).getName(), equalTo("numeric")); assertFieldSelectionContains(fieldExtraction.v2(), @@ -565,23 +566,24 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { contains(equalTo(ExtractedField.Method.SOURCE))); } - public void testDetect_GivenBooleanField_BooleanMappedAsInteger() { + private void testDetect_GivenBooleanField(DataFrameAnalyticsConfig config, boolean isRequired, FieldSelection.FeatureType featureType) { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_boolean", "boolean") + .addAggregatableField("some_integer", "integer") .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, config, false, 100, fieldCapabilities, config.getAnalysis().getFieldCardinalityLimits()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(1)); + assertThat(allFields, hasSize(2)); ExtractedField booleanField = allFields.get(0); assertThat(booleanField.getTypes(), contains("boolean")); assertThat(booleanField.getMethod(), equalTo(ExtractedField.Method.DOC_VALUE)); - assertFieldSelectionContains(fieldExtraction.v2(), - FieldSelection.included("some_boolean", Collections.singleton("boolean"), false, FieldSelection.FeatureType.NUMERICAL) + assertFieldSelectionContains(fieldExtraction.v2().subList(0, 1), + FieldSelection.included("some_boolean", Collections.singleton("boolean"), isRequired, featureType) ); SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build(); @@ -594,34 +596,24 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { assertThat(booleanField.value(hit), arrayContaining(0, 1, 0)); } - public void testDetect_GivenBooleanField_BooleanMappedAsString() { - FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() - .addAggregatableField("some_boolean", "boolean") - .build(); + public void testDetect_GivenBooleanField_OutlierDetection() { + // some_boolean is a non-required, numerical feature in outlier detection analysis + testDetect_GivenBooleanField(buildOutlierDetectionConfig(), false, FieldSelection.FeatureType.NUMERICAL); + } - ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildClassificationConfig("some_boolean"), false, 100, fieldCapabilities, - Collections.singletonMap("some_boolean", 2L)); - Tuple> fieldExtraction = extractedFieldsDetector.detect(); + public void testDetect_GivenBooleanField_Regression() { + // some_boolean is a non-required, numerical feature in regression analysis + testDetect_GivenBooleanField(buildRegressionConfig("some_integer"), false, FieldSelection.FeatureType.NUMERICAL); + } - List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(1)); - ExtractedField booleanField = allFields.get(0); - assertThat(booleanField.getTypes(), contains("boolean")); - assertThat(booleanField.getMethod(), equalTo(ExtractedField.Method.DOC_VALUE)); + public void testDetect_GivenBooleanField_Classification_BooleanIsFeature() { + // some_boolean is a non-required, numerical feature in classification analysis + testDetect_GivenBooleanField(buildClassificationConfig("some_integer"), false, FieldSelection.FeatureType.NUMERICAL); + } - assertFieldSelectionContains(fieldExtraction.v2(), - FieldSelection.included("some_boolean", Collections.singleton("boolean"), true, FieldSelection.FeatureType.CATEGORICAL) - ); - - SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build(); - assertThat(booleanField.value(hit), arrayContaining("true")); - - hit = new SearchHitBuilder(42).addField("some_boolean", false).build(); - assertThat(booleanField.value(hit), arrayContaining("false")); - - hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build(); - assertThat(booleanField.value(hit), arrayContaining("false", "true", "false")); + public void testDetect_GivenBooleanField_Classification_BooleanIsDependentVariable() { + // some_boolean is a required, categorical dependent variable in classification analysis + testDetect_GivenBooleanField(buildClassificationConfig("some_boolean"), true, FieldSelection.FeatureType.CATEGORICAL); } public void testDetect_GivenMultiFields() { @@ -640,7 +632,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildRegressionConfig("a_float"), true, 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(5)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(5)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword")); @@ -671,7 +663,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildClassificationConfig("field_1"), true, 100, fieldCapabilities, Collections.singletonMap("field_1", 2L)); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("field_1", "field_2")); @@ -696,7 +688,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Collections.singletonMap("field_1.keyword", 2L)); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("field_1.keyword", "field_2")); @@ -722,7 +714,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("field_1.keyword_1", "field_2")); @@ -748,7 +740,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildRegressionConfig("field_2"), true, 0, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("field_1", "field_2")); @@ -773,7 +765,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildRegressionConfig("field_2.double"), true, 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("field_1", "field_2.double")); @@ -798,7 +790,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("field_1", "field_2")); @@ -823,7 +815,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { SOURCE_INDEX, buildRegressionConfig("field_2"), false, 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); - assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2)); + assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, contains("field_1", "field_2")); @@ -849,7 +841,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(2)); + assertThat(allFields, hasSize(2)); assertThat(allFields.get(0).getName(), equalTo("field_11")); assertThat(allFields.get(1).getName(), equalTo("field_12")); @@ -872,7 +864,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); - assertThat(allFields.size(), equalTo(2)); + assertThat(allFields, hasSize(2)); assertThat(allFields.get(0).getName(), equalTo("field_21")); assertThat(allFields.get(1).getName(), equalTo("field_22")); @@ -914,7 +906,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase { * We assert each field individually to get useful error messages in case of failure */ private static void assertFieldSelectionContains(List actual, FieldSelection... expected) { - assertThat(actual.size(), equalTo(expected.length)); + assertThat(actual, hasSize(expected.length)); for (int i = 0; i < expected.length; i++) { assertThat("i = " + i, actual.get(i).getName(), equalTo(expected[i].getName())); assertThat("i = " + i, actual.get(i).getMappingTypes(), equalTo(expected[i].getMappingTypes())); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index df4449543d9..2668a3d1d46 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.junit.Before; @@ -95,6 +96,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase { when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS)); dataExtractorFactory = mock(DataFrameDataExtractorFactory.class); when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor); + when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class)); finishHandler = mock(Consumer.class); exceptionCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index b1a2ba226b4..15bd32da3c3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.junit.Before; @@ -219,7 +220,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private void givenDataFrameRows(int rows) { AnalyticsProcessConfig config = new AnalyticsProcessConfig( - "job_id", rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class)); + "job_id", rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class), + mock(ExtractedFields.class)); when(process.getConfig()).thenReturn(config); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java index 5d92117b16b..97b202018b9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; @@ -70,6 +71,7 @@ public class MemoryUsageEstimationProcessManagerTests extends ESTestCase { when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS)); dataExtractorFactory = mock(DataFrameDataExtractorFactory.class); when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor); + when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class)); dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID); listener = mock(ActionListener.class); resultCaptor = ArgumentCaptor.forClass(MemoryUsageEstimationResult.class); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java index 9613d14fb5f..5ac983e7d50 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java @@ -101,7 +101,7 @@ public class ExtractedFieldsTests extends ESTestCase { public void testApplyBooleanMapping() { DocValueField aBool = new DocValueField("a_bool", Collections.singleton("boolean")); - ExtractedField mapped = ExtractedFields.applyBooleanMapping(aBool, 1, 0); + ExtractedField mapped = ExtractedFields.applyBooleanMapping(aBool); SearchHit hitTrue = new SearchHitBuilder(42).addField("a_bool", true).build(); SearchHit hitFalse = new SearchHitBuilder(42).addField("a_bool", false).build();