This commit is contained in:
parent
049d854360
commit
0965a10468
|
@ -67,6 +67,12 @@ public class Classification implements DataFrameAnalysis {
|
|||
.flatMap(Set::stream)
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
/**
|
||||
* Name of the parameter passed down to C++.
|
||||
* This parameter is used to decide which JSON data type from {string, int, bool} to use when writing the prediction.
|
||||
*/
|
||||
private static final String PREDICTION_FIELD_TYPE = "prediction_field_type";
|
||||
|
||||
/**
|
||||
* As long as we only support binary classification it makes sense to always report both classes with their probabilities.
|
||||
* This way the user can see if the prediction was made with confidence they need.
|
||||
|
@ -154,7 +160,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> getParams() {
|
||||
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
params.putAll(boostedTreeParams.getParams());
|
||||
|
@ -162,9 +168,30 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (predictionFieldName != null) {
|
||||
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
|
||||
if (predictionFieldType != null) {
|
||||
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
private static String getPredictionFieldType(Set<String> dependentVariableTypes) {
|
||||
if (dependentVariableTypes == null) {
|
||||
return null;
|
||||
}
|
||||
if (Types.categorical().containsAll(dependentVariableTypes)) {
|
||||
return "string";
|
||||
}
|
||||
if (Types.bool().containsAll(dependentVariableTypes)) {
|
||||
return "bool";
|
||||
}
|
||||
if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
|
||||
// C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers.
|
||||
return "int";
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supportsCategoricalFields() {
|
||||
return true;
|
||||
|
|
|
@ -16,8 +16,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
|
|||
|
||||
/**
|
||||
* @return The analysis parameters as a map
|
||||
* @param extractedFields map of (name, types) for all the extracted fields
|
||||
*/
|
||||
Map<String, Object> getParams();
|
||||
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
|
||||
|
||||
/**
|
||||
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
|
||||
|
|
|
@ -192,7 +192,7 @@ public class OutlierDetection implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> getParams() {
|
||||
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
if (nNeighbors != null) {
|
||||
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);
|
||||
|
|
|
@ -124,7 +124,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> getParams() {
|
||||
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||
params.putAll(boostedTreeParams.getParams());
|
||||
|
|
|
@ -8,11 +8,20 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
|||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
||||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.hamcrest.Matchers;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasEntry;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
@ -115,6 +124,34 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testGetParams() {
|
||||
Map<String, Set<String>> extractedFields = new HashMap<>(3);
|
||||
extractedFields.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
|
||||
extractedFields.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
|
||||
extractedFields.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));
|
||||
assertThat(
|
||||
new Classification("foo").getParams(extractedFields),
|
||||
Matchers.<Map<String, Object>>allOf(
|
||||
hasEntry("dependent_variable", "foo"),
|
||||
hasEntry("num_top_classes", 2),
|
||||
hasEntry("prediction_field_name", "foo_prediction"),
|
||||
hasEntry("prediction_field_type", "bool")));
|
||||
assertThat(
|
||||
new Classification("bar").getParams(extractedFields),
|
||||
Matchers.<Map<String, Object>>allOf(
|
||||
hasEntry("dependent_variable", "bar"),
|
||||
hasEntry("num_top_classes", 2),
|
||||
hasEntry("prediction_field_name", "bar_prediction"),
|
||||
hasEntry("prediction_field_type", "int")));
|
||||
assertThat(
|
||||
new Classification("baz").getParams(extractedFields),
|
||||
Matchers.<Map<String, Object>>allOf(
|
||||
hasEntry("dependent_variable", "baz"),
|
||||
hasEntry("num_top_classes", 2),
|
||||
hasEntry("prediction_field_name", "baz_prediction"),
|
||||
hasEntry("prediction_field_type", "string")));
|
||||
}
|
||||
|
||||
public void testFieldCardinalityLimitsIsNonNull() {
|
||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
|
|||
|
||||
public void testGetParams_GivenDefaults() {
|
||||
OutlierDetection outlierDetection = new OutlierDetection.Builder().build();
|
||||
Map<String, Object> params = outlierDetection.getParams();
|
||||
Map<String, Object> params = outlierDetection.getParams(null);
|
||||
assertThat(params.size(), equalTo(3));
|
||||
assertThat(params.containsKey("compute_feature_influence"), is(true));
|
||||
assertThat(params.get("compute_feature_influence"), is(true));
|
||||
|
@ -71,7 +71,7 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
|
|||
.setStandardizationEnabled(false)
|
||||
.build();
|
||||
|
||||
Map<String, Object> params = outlierDetection.getParams();
|
||||
Map<String, Object> params = outlierDetection.getParams(null);
|
||||
|
||||
assertThat(params.size(), equalTo(6));
|
||||
assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42));
|
||||
|
|
|
@ -12,7 +12,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
|
|||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasEntry;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
@ -83,6 +85,12 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testGetParams() {
|
||||
assertThat(
|
||||
new Regression("foo").getParams(null),
|
||||
allOf(hasEntry("dependent_variable", "foo"), hasEntry("prediction_field_name", "foo_prediction")));
|
||||
}
|
||||
|
||||
public void testFieldCardinalityLimitsIsNonNull() {
|
||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
|
||||
public class TypesTests extends ESTestCase {
|
||||
|
||||
public void testTypes() {
|
||||
assertThat(Sets.intersection(Types.bool(), Types.categorical()), empty());
|
||||
assertThat(Sets.intersection(Types.categorical(), Types.numerical()), empty());
|
||||
assertThat(Sets.intersection(Types.numerical(), Types.bool()), empty());
|
||||
assertThat(Sets.difference(Types.discreteNumerical(), Types.numerical()), empty());
|
||||
}
|
||||
}
|
|
@ -28,8 +28,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
|
||||
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
|
||||
|
||||
private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
|
||||
private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
|
||||
private static final String ANIMAL_NAME_FIELD = "animal_name";
|
||||
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
|
||||
private static final String NO_LEGS_FIELD = "no_legs";
|
||||
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
|
||||
private static final String IS_PREDATOR_FIELD = "predator";
|
||||
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
|
@ -41,9 +45,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
cleanUp();
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_DefaultMetrics() {
|
||||
public void testEvaluate_DefaultMetrics() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -52,10 +56,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_Accuracy() {
|
||||
public void testEvaluate_Accuracy_KeywordField() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy())));
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -74,11 +78,50 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
|
||||
public void testEvaluate_Accuracy_IntegerField() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
equalTo(Arrays.asList(
|
||||
new Accuracy.ActualClass("1", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("2", 15, 2.0 / 15),
|
||||
new Accuracy.ActualClass("3", 15, 3.0 / 15),
|
||||
new Accuracy.ActualClass("4", 15, 4.0 / 15),
|
||||
new Accuracy.ActualClass("5", 15, 5.0 / 15))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_BooleanField() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
equalTo(Arrays.asList(
|
||||
new Accuracy.ActualClass("true", 45, 27.0 / 45),
|
||||
new Accuracy.ActualClass("false", 30, 18.0 / 30))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -137,11 +180,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
|
||||
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -168,20 +211,30 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
|
||||
private static void indexAnimalsData(String indexName) {
|
||||
client().admin().indices().prepareCreate(indexName)
|
||||
.addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
|
||||
.addMapping("_doc",
|
||||
ANIMAL_NAME_FIELD, "type=keyword",
|
||||
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
|
||||
NO_LEGS_FIELD, "type=integer",
|
||||
NO_LEGS_PREDICTION_FIELD, "type=integer",
|
||||
IS_PREDATOR_FIELD, "type=boolean",
|
||||
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
|
||||
.get();
|
||||
|
||||
List<String> classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
|
||||
List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < classNames.size(); i++) {
|
||||
for (int j = 0; j < classNames.size(); j++) {
|
||||
for (int i = 0; i < animalNames.size(); i++) {
|
||||
for (int j = 0; j < animalNames.size(); j++) {
|
||||
for (int k = 0; k < j + 1; k++) {
|
||||
bulkRequestBuilder.add(
|
||||
new IndexRequest(indexName)
|
||||
.source(
|
||||
ACTUAL_CLASS_FIELD, classNames.get(i),
|
||||
PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
|
||||
ANIMAL_NAME_FIELD, animalNames.get(i),
|
||||
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
|
||||
NO_LEGS_FIELD, i + 1,
|
||||
NO_LEGS_PREDICTION_FIELD, j + 1,
|
||||
IS_PREDATOR_FIELD, i % 2 == 0,
|
||||
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,9 +20,8 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -30,7 +29,6 @@ import java.util.Arrays;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
|
@ -88,7 +86,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
@ -102,7 +100,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction");
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||
|
@ -128,7 +126,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(true));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
@ -142,11 +140,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction");
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword");
|
||||
}
|
||||
|
||||
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
String jobId, String dependentVariable, List<T> dependentVariableValues, Function<String, T> parser) throws Exception {
|
||||
String jobId, String dependentVariable, List<T> dependentVariableValues) throws Exception {
|
||||
initialize(jobId);
|
||||
String predictedClassField = dependentVariable + "_prediction";
|
||||
indexData(sourceIndex, 300, 0, dependentVariable);
|
||||
|
@ -175,9 +173,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||
T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField));
|
||||
@SuppressWarnings("unchecked")
|
||||
T predictedClassValue = (T) resultsObject.get(predictedClassField);
|
||||
assertThat(predictedClassValue, is(in(dependentVariableValues)));
|
||||
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser);
|
||||
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues);
|
||||
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
// Let's just assert there's both training and non-training results
|
||||
|
@ -201,33 +200,32 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"Creating destination index [" + destIndex + "]",
|
||||
"Finished reindexing to destination index [" + destIndex + "]",
|
||||
"Finished analysis");
|
||||
assertEvaluation(
|
||||
dependentVariable,
|
||||
dependentVariableValues.stream().map(String::valueOf).collect(toList()),
|
||||
"ml." + predictedClassField);
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
|
||||
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
|
||||
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception {
|
||||
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf);
|
||||
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES);
|
||||
assertEvaluation(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "ml.discrete-numerical-field_prediction");
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
|
||||
ElasticsearchStatusException e = expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf));
|
||||
"classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES));
|
||||
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception {
|
||||
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf);
|
||||
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES);
|
||||
assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction");
|
||||
}
|
||||
|
||||
public void testDependentVariableCardinalityTooHighError() {
|
||||
|
@ -317,25 +315,24 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
return resultsObject;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static <T> void assertTopClasses(
|
||||
Map<String, Object> resultsObject,
|
||||
int numTopClasses,
|
||||
String dependentVariable,
|
||||
List<T> dependentVariableValues,
|
||||
Function<String, T> parser) {
|
||||
List<T> dependentVariableValues) {
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(true));
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
|
||||
assertThat(topClasses, hasSize(numTopClasses));
|
||||
List<String> classNames = new ArrayList<>(topClasses.size());
|
||||
List<T> classNames = new ArrayList<>(topClasses.size());
|
||||
List<Double> classProbabilities = new ArrayList<>(topClasses.size());
|
||||
for (Map<String, Object> topClass : topClasses) {
|
||||
assertThat(topClass, allOf(hasKey("class_name"), hasKey("class_probability")));
|
||||
classNames.add((String) topClass.get("class_name"));
|
||||
classNames.add((T) topClass.get("class_name"));
|
||||
classProbabilities.add((Double) topClass.get("class_probability"));
|
||||
}
|
||||
// Assert that all the predicted class names come from the set of dependent variable values.
|
||||
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
|
||||
classNames.forEach(className -> assertThat(className, is(in(dependentVariableValues))));
|
||||
// Assert that the first class listed in top classes is the same as the predicted class.
|
||||
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
|
||||
// Assert that all the class probabilities lie within [0, 1] interval.
|
||||
|
@ -344,25 +341,44 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true));
|
||||
}
|
||||
|
||||
private void assertEvaluation(String dependentVariable, List<String> dependentVariableValues, String predictedClassField) {
|
||||
private <T> void assertEvaluation(String dependentVariable, List<T> dependentVariableValues, String predictedClassField) {
|
||||
List<String> dependentVariableValuesAsStrings = dependentVariableValues.stream().map(String::valueOf).collect(toList());
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
destIndex,
|
||||
new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
|
||||
dependentVariable, predictedClassField, null));
|
||||
dependentVariable, predictedClassField, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix())));
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
List<ActualClass> actualClasses = confusionMatrixResult.getConfusionMatrix();
|
||||
assertThat(actualClasses.stream().map(ActualClass::getActualClass).collect(toList()), equalTo(dependentVariableValues));
|
||||
for (ActualClass actualClass : actualClasses) {
|
||||
assertThat(actualClass.getOtherPredictedClassDocCount(), equalTo(0L));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
|
||||
|
||||
{ // Accuracy
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
List<Accuracy.ActualClass> actualClasses = accuracyResult.getActualClasses();
|
||||
assertThat(
|
||||
actualClass.getPredictedClasses().stream().map(PredictedClass::getPredictedClass).collect(toList()),
|
||||
equalTo(dependentVariableValues));
|
||||
actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()),
|
||||
equalTo(dependentVariableValuesAsStrings));
|
||||
actualClasses.forEach(
|
||||
actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
|
||||
}
|
||||
|
||||
{ // MulticlassConfusionMatrix
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
List<MulticlassConfusionMatrix.ActualClass> actualClasses = confusionMatrixResult.getConfusionMatrix();
|
||||
assertThat(
|
||||
actualClasses.stream().map(MulticlassConfusionMatrix.ActualClass::getActualClass).collect(toList()),
|
||||
equalTo(dependentVariableValuesAsStrings));
|
||||
for (MulticlassConfusionMatrix.ActualClass actualClass : actualClasses) {
|
||||
assertThat(actualClass.getOtherPredictedClassDocCount(), equalTo(0L));
|
||||
assertThat(
|
||||
actualClass.getPredictedClasses().stream()
|
||||
.map(MulticlassConfusionMatrix.PredictedClass::getPredictedClass)
|
||||
.collect(toList()),
|
||||
equalTo(dependentVariableValuesAsStrings));
|
||||
}
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,6 +56,10 @@ public class DataFrameDataExtractorFactory {
|
|||
return new DataFrameDataExtractor(client, context);
|
||||
}
|
||||
|
||||
public ExtractedFields getExtractedFields() {
|
||||
return extractedFields;
|
||||
}
|
||||
|
||||
private QueryBuilder createQuery() {
|
||||
BoolQueryBuilder query = QueryBuilders.boolQuery();
|
||||
query.filter(sourceQuery);
|
||||
|
|
|
@ -379,14 +379,12 @@ public class ExtractedFieldsDetector {
|
|||
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
|
||||
for (ExtractedField field : extractedFields.getAllFields()) {
|
||||
if (isBoolean(field.getTypes())) {
|
||||
if (config.getAnalysis().getAllowedCategoricalTypes(field.getName()).contains(BooleanFieldMapper.CONTENT_TYPE)) {
|
||||
// We convert boolean field to string if it is a categorical dependent variable
|
||||
adjusted.add(ExtractedFields.applyBooleanMapping(field, Boolean.TRUE.toString(), Boolean.FALSE.toString()));
|
||||
} else {
|
||||
// We convert boolean fields to integers with values 0, 1 as this is the preferred
|
||||
// way to consume such features in the analytics process.
|
||||
adjusted.add(ExtractedFields.applyBooleanMapping(field, 1, 0));
|
||||
}
|
||||
// We convert boolean fields to integers with values 0, 1 as this is the preferred
|
||||
// way to consume such features in the analytics process regardless of:
|
||||
// - analysis type
|
||||
// - whether or not the field is categorical
|
||||
// - whether or not the field is a dependent variable
|
||||
adjusted.add(ExtractedFields.applyBooleanMapping(field));
|
||||
} else {
|
||||
adjusted.add(field);
|
||||
}
|
||||
|
|
|
@ -9,11 +9,15 @@ import org.elasticsearch.common.unit.ByteSizeValue;
|
|||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import static java.util.stream.Collectors.toMap;
|
||||
|
||||
public class AnalyticsProcessConfig implements ToXContentObject {
|
||||
|
||||
private static final String JOB_ID = "job_id";
|
||||
|
@ -33,9 +37,10 @@ public class AnalyticsProcessConfig implements ToXContentObject {
|
|||
private final String resultsField;
|
||||
private final Set<String> categoricalFields;
|
||||
private final DataFrameAnalysis analysis;
|
||||
private final ExtractedFields extractedFields;
|
||||
|
||||
public AnalyticsProcessConfig(String jobId, long rows, int cols, ByteSizeValue memoryLimit, int threads, String resultsField,
|
||||
Set<String> categoricalFields, DataFrameAnalysis analysis) {
|
||||
Set<String> categoricalFields, DataFrameAnalysis analysis, ExtractedFields extractedFields) {
|
||||
this.jobId = Objects.requireNonNull(jobId);
|
||||
this.rows = rows;
|
||||
this.cols = cols;
|
||||
|
@ -44,6 +49,7 @@ public class AnalyticsProcessConfig implements ToXContentObject {
|
|||
this.resultsField = Objects.requireNonNull(resultsField);
|
||||
this.categoricalFields = Objects.requireNonNull(categoricalFields);
|
||||
this.analysis = Objects.requireNonNull(analysis);
|
||||
this.extractedFields = Objects.requireNonNull(extractedFields);
|
||||
}
|
||||
|
||||
public String jobId() {
|
||||
|
@ -68,7 +74,7 @@ public class AnalyticsProcessConfig implements ToXContentObject {
|
|||
builder.field(THREADS, threads);
|
||||
builder.field(RESULTS_FIELD, resultsField);
|
||||
builder.field(CATEGORICAL_FIELDS, categoricalFields);
|
||||
builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis));
|
||||
builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis, extractedFields));
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -76,16 +82,21 @@ public class AnalyticsProcessConfig implements ToXContentObject {
|
|||
private static class DataFrameAnalysisWrapper implements ToXContentObject {
|
||||
|
||||
private final DataFrameAnalysis analysis;
|
||||
private final ExtractedFields extractedFields;
|
||||
|
||||
private DataFrameAnalysisWrapper(DataFrameAnalysis analysis) {
|
||||
private DataFrameAnalysisWrapper(DataFrameAnalysis analysis, ExtractedFields extractedFields) {
|
||||
this.analysis = analysis;
|
||||
this.extractedFields = extractedFields;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field("name", analysis.getWriteableName());
|
||||
builder.field("parameters", analysis.getParams());
|
||||
builder.field(
|
||||
"parameters",
|
||||
analysis.getParams(
|
||||
extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes))));
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact
|
|||
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
|
||||
|
@ -373,7 +374,8 @@ public class AnalyticsProcessManager {
|
|||
}
|
||||
|
||||
dataExtractor = dataExtractorFactory.newExtractor(false);
|
||||
AnalyticsProcessConfig analyticsProcessConfig = createProcessConfig(config, dataExtractor);
|
||||
AnalyticsProcessConfig analyticsProcessConfig =
|
||||
createProcessConfig(config, dataExtractor, dataExtractorFactory.getExtractedFields());
|
||||
LOGGER.trace("[{}] creating analytics process with config [{}]", config.getId(), Strings.toString(analyticsProcessConfig));
|
||||
// If we have no rows, that means there is no data so no point in starting the native process
|
||||
// just finish the task
|
||||
|
@ -389,11 +391,20 @@ public class AnalyticsProcessManager {
|
|||
return true;
|
||||
}
|
||||
|
||||
private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) {
|
||||
private AnalyticsProcessConfig createProcessConfig(
|
||||
DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor, ExtractedFields extractedFields) {
|
||||
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
|
||||
Set<String> categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis());
|
||||
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(config.getId(), dataSummary.rows, dataSummary.cols,
|
||||
config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis());
|
||||
AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(
|
||||
config.getId(),
|
||||
dataSummary.rows,
|
||||
dataSummary.cols,
|
||||
config.getModelMemoryLimit(),
|
||||
1,
|
||||
config.getDest().getResultsField(),
|
||||
categoricalFields,
|
||||
config.getAnalysis(),
|
||||
extractedFields);
|
||||
return processConfig;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,7 +79,8 @@ public class MemoryUsageEstimationProcessManager {
|
|||
1,
|
||||
"",
|
||||
categoricalFields,
|
||||
config.getAnalysis());
|
||||
config.getAnalysis(),
|
||||
dataExtractorFactory.getExtractedFields());
|
||||
AnalyticsProcess<MemoryUsageEstimationResult> process =
|
||||
processFactory.createAnalyticsProcess(
|
||||
config,
|
||||
|
|
|
@ -62,8 +62,8 @@ public class ExtractedFields {
|
|||
return new TimeField(name, method);
|
||||
}
|
||||
|
||||
public static <T> ExtractedField applyBooleanMapping(ExtractedField field, T trueValue, T falseValue) {
|
||||
return new BooleanMapper<>(field, trueValue, falseValue);
|
||||
public static ExtractedField applyBooleanMapping(ExtractedField field) {
|
||||
return new BooleanMapper<>(field, 1, 0);
|
||||
}
|
||||
|
||||
public static class ExtractionMethodDetector {
|
||||
|
|
|
@ -36,6 +36,7 @@ import static org.hamcrest.Matchers.arrayContaining;
|
|||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
|
@ -57,7 +58,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(1));
|
||||
assertThat(allFields, hasSize(1));
|
||||
assertThat(allFields.get(0).getName(), equalTo("some_float"));
|
||||
assertThat(allFields.get(0).getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
|
||||
|
||||
|
@ -75,7 +76,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(1));
|
||||
assertThat(allFields, hasSize(1));
|
||||
assertThat(allFields.get(0).getName(), equalTo("some_number"));
|
||||
assertThat(allFields.get(0).getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
|
||||
|
||||
|
@ -121,7 +122,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(3));
|
||||
assertThat(allFields, hasSize(3));
|
||||
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toSet()),
|
||||
containsInAnyOrder("some_float", "some_long", "some_boolean"));
|
||||
assertThat(allFields.stream().map(ExtractedField::getMethod).collect(Collectors.toSet()),
|
||||
|
@ -150,7 +151,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(5));
|
||||
assertThat(allFields, hasSize(5));
|
||||
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()),
|
||||
containsInAnyOrder("foo", "some_float", "some_keyword", "some_long", "some_boolean"));
|
||||
assertThat(allFields.stream().map(ExtractedField::getMethod).collect(Collectors.toSet()),
|
||||
|
@ -223,7 +224,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(1));
|
||||
assertThat(allFields, hasSize(1));
|
||||
assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()), contains("bar"));
|
||||
|
||||
assertFieldSelectionContains(fieldExtraction.v2(),
|
||||
|
@ -329,7 +330,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(1));
|
||||
assertThat(allFields, hasSize(1));
|
||||
assertThat(allFields.get(0).getName(), equalTo("numeric"));
|
||||
|
||||
assertFieldSelectionContains(fieldExtraction.v2(),
|
||||
|
@ -565,23 +566,24 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
contains(equalTo(ExtractedField.Method.SOURCE)));
|
||||
}
|
||||
|
||||
public void testDetect_GivenBooleanField_BooleanMappedAsInteger() {
|
||||
private void testDetect_GivenBooleanField(DataFrameAnalyticsConfig config, boolean isRequired, FieldSelection.FeatureType featureType) {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("some_boolean", "boolean")
|
||||
.addAggregatableField("some_integer", "integer")
|
||||
.build();
|
||||
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap());
|
||||
SOURCE_INDEX, config, false, 100, fieldCapabilities, config.getAnalysis().getFieldCardinalityLimits());
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(1));
|
||||
assertThat(allFields, hasSize(2));
|
||||
ExtractedField booleanField = allFields.get(0);
|
||||
assertThat(booleanField.getTypes(), contains("boolean"));
|
||||
assertThat(booleanField.getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
|
||||
|
||||
assertFieldSelectionContains(fieldExtraction.v2(),
|
||||
FieldSelection.included("some_boolean", Collections.singleton("boolean"), false, FieldSelection.FeatureType.NUMERICAL)
|
||||
assertFieldSelectionContains(fieldExtraction.v2().subList(0, 1),
|
||||
FieldSelection.included("some_boolean", Collections.singleton("boolean"), isRequired, featureType)
|
||||
);
|
||||
|
||||
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
|
||||
|
@ -594,34 +596,24 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
assertThat(booleanField.value(hit), arrayContaining(0, 1, 0));
|
||||
}
|
||||
|
||||
public void testDetect_GivenBooleanField_BooleanMappedAsString() {
|
||||
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
|
||||
.addAggregatableField("some_boolean", "boolean")
|
||||
.build();
|
||||
public void testDetect_GivenBooleanField_OutlierDetection() {
|
||||
// some_boolean is a non-required, numerical feature in outlier detection analysis
|
||||
testDetect_GivenBooleanField(buildOutlierDetectionConfig(), false, FieldSelection.FeatureType.NUMERICAL);
|
||||
}
|
||||
|
||||
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
|
||||
SOURCE_INDEX, buildClassificationConfig("some_boolean"), false, 100, fieldCapabilities,
|
||||
Collections.singletonMap("some_boolean", 2L));
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
public void testDetect_GivenBooleanField_Regression() {
|
||||
// some_boolean is a non-required, numerical feature in regression analysis
|
||||
testDetect_GivenBooleanField(buildRegressionConfig("some_integer"), false, FieldSelection.FeatureType.NUMERICAL);
|
||||
}
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(1));
|
||||
ExtractedField booleanField = allFields.get(0);
|
||||
assertThat(booleanField.getTypes(), contains("boolean"));
|
||||
assertThat(booleanField.getMethod(), equalTo(ExtractedField.Method.DOC_VALUE));
|
||||
public void testDetect_GivenBooleanField_Classification_BooleanIsFeature() {
|
||||
// some_boolean is a non-required, numerical feature in classification analysis
|
||||
testDetect_GivenBooleanField(buildClassificationConfig("some_integer"), false, FieldSelection.FeatureType.NUMERICAL);
|
||||
}
|
||||
|
||||
assertFieldSelectionContains(fieldExtraction.v2(),
|
||||
FieldSelection.included("some_boolean", Collections.singleton("boolean"), true, FieldSelection.FeatureType.CATEGORICAL)
|
||||
);
|
||||
|
||||
SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
|
||||
assertThat(booleanField.value(hit), arrayContaining("true"));
|
||||
|
||||
hit = new SearchHitBuilder(42).addField("some_boolean", false).build();
|
||||
assertThat(booleanField.value(hit), arrayContaining("false"));
|
||||
|
||||
hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build();
|
||||
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
|
||||
public void testDetect_GivenBooleanField_Classification_BooleanIsDependentVariable() {
|
||||
// some_boolean is a required, categorical dependent variable in classification analysis
|
||||
testDetect_GivenBooleanField(buildClassificationConfig("some_boolean"), true, FieldSelection.FeatureType.CATEGORICAL);
|
||||
}
|
||||
|
||||
public void testDetect_GivenMultiFields() {
|
||||
|
@ -640,7 +632,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
SOURCE_INDEX, buildRegressionConfig("a_float"), true, 100, fieldCapabilities, Collections.emptyMap());
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(5));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(5));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword"));
|
||||
|
@ -671,7 +663,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
SOURCE_INDEX, buildClassificationConfig("field_1"), true, 100, fieldCapabilities, Collections.singletonMap("field_1", 2L));
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("field_1", "field_2"));
|
||||
|
@ -696,7 +688,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Collections.singletonMap("field_1.keyword", 2L));
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("field_1.keyword", "field_2"));
|
||||
|
@ -722,7 +714,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap());
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("field_1.keyword_1", "field_2"));
|
||||
|
@ -748,7 +740,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
SOURCE_INDEX, buildRegressionConfig("field_2"), true, 0, fieldCapabilities, Collections.emptyMap());
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("field_1", "field_2"));
|
||||
|
@ -773,7 +765,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
SOURCE_INDEX, buildRegressionConfig("field_2.double"), true, 100, fieldCapabilities, Collections.emptyMap());
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("field_1", "field_2.double"));
|
||||
|
@ -798,7 +790,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap());
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("field_1", "field_2"));
|
||||
|
@ -823,7 +815,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
SOURCE_INDEX, buildRegressionConfig("field_2"), false, 100, fieldCapabilities, Collections.emptyMap());
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
assertThat(fieldExtraction.v1().getAllFields().size(), equalTo(2));
|
||||
assertThat(fieldExtraction.v1().getAllFields(), hasSize(2));
|
||||
List<String> extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName)
|
||||
.collect(Collectors.toList());
|
||||
assertThat(extractedFieldNames, contains("field_1", "field_2"));
|
||||
|
@ -849,7 +841,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(2));
|
||||
assertThat(allFields, hasSize(2));
|
||||
assertThat(allFields.get(0).getName(), equalTo("field_11"));
|
||||
assertThat(allFields.get(1).getName(), equalTo("field_12"));
|
||||
|
||||
|
@ -872,7 +864,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();
|
||||
assertThat(allFields.size(), equalTo(2));
|
||||
assertThat(allFields, hasSize(2));
|
||||
assertThat(allFields.get(0).getName(), equalTo("field_21"));
|
||||
assertThat(allFields.get(1).getName(), equalTo("field_22"));
|
||||
|
||||
|
@ -914,7 +906,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
|
|||
* We assert each field individually to get useful error messages in case of failure
|
||||
*/
|
||||
private static void assertFieldSelectionContains(List<FieldSelection> actual, FieldSelection... expected) {
|
||||
assertThat(actual.size(), equalTo(expected.length));
|
||||
assertThat(actual, hasSize(expected.length));
|
||||
for (int i = 0; i < expected.length; i++) {
|
||||
assertThat("i = " + i, actual.get(i).getName(), equalTo(expected[i].getName()));
|
||||
assertThat("i = " + i, actual.get(i).getMappingTypes(), equalTo(expected[i].getMappingTypes()));
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
|
|||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
import org.junit.Before;
|
||||
|
@ -95,6 +96,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
|
|||
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
|
||||
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
|
||||
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
|
||||
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
|
||||
finishHandler = mock(Consumer.class);
|
||||
|
||||
exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
|||
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
import org.junit.Before;
|
||||
|
@ -219,7 +220,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
|||
|
||||
private void givenDataFrameRows(int rows) {
|
||||
AnalyticsProcessConfig config = new AnalyticsProcessConfig(
|
||||
"job_id", rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class));
|
||||
"job_id", rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class),
|
||||
mock(ExtractedFields.class));
|
||||
when(process.getConfig()).thenReturn(config);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.InOrder;
|
||||
|
@ -70,6 +71,7 @@ public class MemoryUsageEstimationProcessManagerTests extends ESTestCase {
|
|||
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
|
||||
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
|
||||
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
|
||||
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
|
||||
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID);
|
||||
listener = mock(ActionListener.class);
|
||||
resultCaptor = ArgumentCaptor.forClass(MemoryUsageEstimationResult.class);
|
||||
|
|
|
@ -101,7 +101,7 @@ public class ExtractedFieldsTests extends ESTestCase {
|
|||
public void testApplyBooleanMapping() {
|
||||
DocValueField aBool = new DocValueField("a_bool", Collections.singleton("boolean"));
|
||||
|
||||
ExtractedField mapped = ExtractedFields.applyBooleanMapping(aBool, 1, 0);
|
||||
ExtractedField mapped = ExtractedFields.applyBooleanMapping(aBool);
|
||||
|
||||
SearchHit hitTrue = new SearchHitBuilder(42).addField("a_bool", true).build();
|
||||
SearchHit hitFalse = new SearchHitBuilder(42).addField("a_bool", false).build();
|
||||
|
|
Loading…
Reference in New Issue