This commit is contained in:
parent
c547fabb2b
commit
d40afc7871
|
@ -143,7 +143,7 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
if (result.get() == null) { // These are steps 2, 3, 4 etc.
|
||||
KeyedFilter[] keyedFiltersPredicted =
|
||||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
// Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that
|
||||
// too_many_buckets_exception exception is not thrown.
|
||||
|
@ -154,7 +154,7 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
topActualClassNames.get().stream()
|
||||
.skip(actualClasses.size())
|
||||
.limit(actualClassesPerBatch)
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(actualField, className).lenient(true)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
if (keyedFiltersActual.length > 0) {
|
||||
return Tuple.tuple(
|
||||
|
|
|
@ -108,7 +108,7 @@ public class Precision implements EvaluationMetric {
|
|||
if (result.get() == null) { // This is step 2
|
||||
KeyedFilter[] keyedFiltersPredicted =
|
||||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
|
||||
return Tuple.tuple(
|
||||
|
|
|
@ -29,23 +29,25 @@ import java.util.List;
|
|||
import static java.util.stream.Collectors.toList;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.notANumber;
|
||||
|
||||
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
|
||||
|
||||
private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
|
||||
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
|
||||
private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction";
|
||||
private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
|
||||
private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
|
||||
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
|
||||
private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction";
|
||||
private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
|
||||
private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
|
||||
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
|
||||
private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction";
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
|
@ -64,7 +66,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
|
||||
public void testEvaluate_DefaultMetrics() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
|
@ -78,7 +81,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
ANIMALS_DATA_INDEX,
|
||||
new Classification(
|
||||
ANIMAL_NAME_KEYWORD_FIELD,
|
||||
ANIMAL_NAME_PREDICTION_FIELD,
|
||||
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
|
||||
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
|
@ -91,163 +94,257 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
Recall.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_KeywordField() {
|
||||
public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() {
|
||||
String indexName = "some-index";
|
||||
String actualField = "fieldA";
|
||||
String predictedField = "fieldB";
|
||||
client().admin().indices().prepareCreate(indexName)
|
||||
.addMapping("_doc",
|
||||
actualField, "type=keyword",
|
||||
predictedField, "type=keyword")
|
||||
.get();
|
||||
client().prepareIndex(indexName, "_doc")
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.setSource(
|
||||
actualField, "crocodile",
|
||||
predictedField, "cRoCoDiLe")
|
||||
.get();
|
||||
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
indexName,
|
||||
new Classification(
|
||||
actualField,
|
||||
predictedField,
|
||||
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
|
||||
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(accuracyResult.getClasses(), contains(new Accuracy.PerClassResult("crocodile", 0.0)));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.0));
|
||||
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1);
|
||||
assertThat(
|
||||
accuracyResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("ant", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("cat", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("dog", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("fox", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("mouse", 47.0 / 75))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
|
||||
confusionMatrixResult.getConfusionMatrix(),
|
||||
equalTo(Arrays.asList(
|
||||
new MulticlassConfusionMatrix.ActualClass(
|
||||
"crocodile", 1, Arrays.asList(new MulticlassConfusionMatrix.PredictedClass("crocodile", 0L)), 1))));
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2);
|
||||
assertThat(precisionResult.getClasses(), empty());
|
||||
assertThat(precisionResult.getAvgPrecision(), is(notANumber()));
|
||||
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3);
|
||||
assertThat(recallResult.getClasses(), contains(new Recall.PerClassResult("crocodile", 0.0)));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(0.0));
|
||||
}
|
||||
|
||||
private void evaluateAccuracy_IntegerField(String actualField) {
|
||||
private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("1", 57.0 / 75),
|
||||
new Accuracy.PerClassResult("2", 54.0 / 75),
|
||||
new Accuracy.PerClassResult("3", 51.0 / 75),
|
||||
new Accuracy.PerClassResult("4", 48.0 / 75),
|
||||
new Accuracy.PerClassResult("5", 45.0 / 75))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
|
||||
return accuracyResult;
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_KeywordField() {
|
||||
List<Accuracy.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("ant", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("cat", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("dog", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("fox", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("mouse", 47.0 / 75));
|
||||
double expectedOverallAccuracy = 5.0 / 75;
|
||||
|
||||
Accuracy.Result accuracyResult = evaluateAccuracy(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
|
||||
// Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric)
|
||||
accuracyResult = evaluateAccuracy(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_IntegerField() {
|
||||
evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD);
|
||||
}
|
||||
List<Accuracy.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("1", 57.0 / 75),
|
||||
new Accuracy.PerClassResult("2", 54.0 / 75),
|
||||
new Accuracy.PerClassResult("3", 51.0 / 75),
|
||||
new Accuracy.PerClassResult("4", 48.0 / 75),
|
||||
new Accuracy.PerClassResult("5", 45.0 / 75));
|
||||
double expectedOverallAccuracy = 15.0 / 75;
|
||||
|
||||
public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() {
|
||||
evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
|
||||
private void evaluateAccuracy_BooleanField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
// Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric)
|
||||
accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
accuracyResult = evaluateAccuracy(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("false", 18.0 / 30),
|
||||
new Accuracy.PerClassResult("true", 27.0 / 45))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_BooleanField() {
|
||||
evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
|
||||
List<Accuracy.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("false", 18.0 / 30),
|
||||
new Accuracy.PerClassResult("true", 27.0 / 45));
|
||||
double expectedOverallAccuracy = 45.0 / 75;
|
||||
|
||||
Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
|
||||
// Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric)
|
||||
accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
accuracyResult = evaluateAccuracy(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() {
|
||||
evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
|
||||
public void testEvaluate_Accuracy_FieldTypeMismatch() {
|
||||
{
|
||||
// When actual and predicted fields have different types, the sets of classes are disjoint
|
||||
List<Accuracy.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("1", 0.8),
|
||||
new Accuracy.PerClassResult("2", 0.8),
|
||||
new Accuracy.PerClassResult("3", 0.8),
|
||||
new Accuracy.PerClassResult("4", 0.8),
|
||||
new Accuracy.PerClassResult("5", 0.8));
|
||||
double expectedOverallAccuracy = 0.0;
|
||||
|
||||
Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
}
|
||||
{
|
||||
// When actual and predicted fields have different types, the sets of classes are disjoint
|
||||
List<Accuracy.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("false", 0.6),
|
||||
new Accuracy.PerClassResult("true", 0.4));
|
||||
double expectedOverallAccuracy = 0.0;
|
||||
|
||||
Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD);
|
||||
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
|
||||
}
|
||||
}
|
||||
|
||||
private Precision.Result evaluatePrecision(String actualField, String predictedField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Precision())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
return precisionResult;
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_KeywordField() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
List<Precision.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("ant", 1.0 / 15),
|
||||
new Precision.PerClassResult("cat", 1.0 / 15),
|
||||
new Precision.PerClassResult("dog", 1.0 / 15),
|
||||
new Precision.PerClassResult("fox", 1.0 / 15),
|
||||
new Precision.PerClassResult("mouse", 1.0 / 15));
|
||||
double expectedAvgPrecision = 5.0 / 75;
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
Precision.Result precisionResult = evaluatePrecision(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD);
|
||||
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("ant", 1.0 / 15),
|
||||
new Precision.PerClassResult("cat", 1.0 / 15),
|
||||
new Precision.PerClassResult("dog", 1.0 / 15),
|
||||
new Precision.PerClassResult("fox", 1.0 / 15),
|
||||
new Precision.PerClassResult("mouse", 1.0 / 15))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
private void evaluatePrecision_IntegerField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("1", 0.2),
|
||||
new Precision.PerClassResult("2", 0.2),
|
||||
new Precision.PerClassResult("3", 0.2),
|
||||
new Precision.PerClassResult("4", 0.2),
|
||||
new Precision.PerClassResult("5", 0.2))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(0.2));
|
||||
evaluatePrecision(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_IntegerField() {
|
||||
evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD);
|
||||
}
|
||||
List<Precision.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("1", 0.2),
|
||||
new Precision.PerClassResult("2", 0.2),
|
||||
new Precision.PerClassResult("3", 0.2),
|
||||
new Precision.PerClassResult("4", 0.2),
|
||||
new Precision.PerClassResult("5", 0.2));
|
||||
double expectedAvgPrecision = 0.2;
|
||||
|
||||
public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() {
|
||||
evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
|
||||
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
|
||||
|
||||
private void evaluatePrecision_BooleanField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
precisionResult = evaluatePrecision(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
|
||||
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD);
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("false", 0.5),
|
||||
new Precision.PerClassResult("true", 9.0 / 13))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52));
|
||||
evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_BooleanField() {
|
||||
evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
|
||||
List<Precision.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("false", 0.5),
|
||||
new Precision.PerClassResult("true", 9.0 / 13));
|
||||
double expectedAvgPrecision = 31.0 / 52;
|
||||
|
||||
Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
|
||||
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
|
||||
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
precisionResult = evaluatePrecision(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
|
||||
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
|
||||
|
||||
evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
|
||||
|
||||
evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() {
|
||||
evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
|
||||
public void testEvaluate_Precision_FieldTypeMismatch() {
|
||||
{
|
||||
Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
|
||||
// When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here
|
||||
assertThat(precisionResult.getClasses(), empty());
|
||||
assertThat(precisionResult.getAvgPrecision(), is(notANumber()));
|
||||
}
|
||||
{
|
||||
Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD);
|
||||
// When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here
|
||||
assertThat(precisionResult.getClasses(), empty());
|
||||
assertThat(precisionResult.getAvgPrecision(), is(notANumber()));
|
||||
}
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_CardinalityTooHigh() {
|
||||
|
@ -257,88 +354,112 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Precision()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_KeywordField() {
|
||||
private Recall.Result evaluateRecall(String actualField, String predictedField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Recall())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("ant", 1.0 / 15),
|
||||
new Recall.PerClassResult("cat", 1.0 / 15),
|
||||
new Recall.PerClassResult("dog", 1.0 / 15),
|
||||
new Recall.PerClassResult("fox", 1.0 / 15),
|
||||
new Recall.PerClassResult("mouse", 1.0 / 15))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75));
|
||||
return recallResult;
|
||||
}
|
||||
|
||||
private void evaluateRecall_IntegerField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_INTEGER_FIELD, Arrays.asList(new Recall())));
|
||||
public void testEvaluate_Recall_KeywordField() {
|
||||
List<Recall.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("ant", 1.0 / 15),
|
||||
new Recall.PerClassResult("cat", 1.0 / 15),
|
||||
new Recall.PerClassResult("dog", 1.0 / 15),
|
||||
new Recall.PerClassResult("fox", 1.0 / 15),
|
||||
new Recall.PerClassResult("mouse", 1.0 / 15));
|
||||
double expectedAvgRecall = 5.0 / 75;
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
Recall.Result recallResult = evaluateRecall(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD);
|
||||
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
|
||||
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("1", 1.0),
|
||||
new Recall.PerClassResult("2", 1.0),
|
||||
new Recall.PerClassResult("3", 1.0),
|
||||
new Recall.PerClassResult("4", 1.0),
|
||||
new Recall.PerClassResult("5", 1.0))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(1.0));
|
||||
evaluateRecall(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_IntegerField() {
|
||||
evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD);
|
||||
}
|
||||
List<Recall.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("1", 1.0 / 15),
|
||||
new Recall.PerClassResult("2", 2.0 / 15),
|
||||
new Recall.PerClassResult("3", 3.0 / 15),
|
||||
new Recall.PerClassResult("4", 4.0 / 15),
|
||||
new Recall.PerClassResult("5", 5.0 / 15));
|
||||
double expectedAvgRecall = 3.0 / 15;
|
||||
|
||||
public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() {
|
||||
evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
|
||||
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
|
||||
|
||||
private void evaluateRecall_BooleanField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Recall())));
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
recallResult = evaluateRecall(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
|
||||
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD);
|
||||
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("true", 0.6),
|
||||
new Recall.PerClassResult("false", 0.6))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(0.6));
|
||||
evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_BooleanField() {
|
||||
evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
|
||||
List<Recall.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("true", 0.6),
|
||||
new Recall.PerClassResult("false", 0.6));
|
||||
double expectedAvgRecall = 0.6;
|
||||
|
||||
Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
|
||||
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
|
||||
|
||||
// Actual and predicted fields are of different types but the values are matched correctly
|
||||
recallResult = evaluateRecall(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
|
||||
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
|
||||
|
||||
evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
|
||||
|
||||
evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() {
|
||||
evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
|
||||
public void testEvaluate_Recall_FieldTypeMismatch() {
|
||||
{
|
||||
// When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here
|
||||
List<Recall.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("1", 0.0),
|
||||
new Recall.PerClassResult("2", 0.0),
|
||||
new Recall.PerClassResult("3", 0.0),
|
||||
new Recall.PerClassResult("4", 0.0),
|
||||
new Recall.PerClassResult("5", 0.0));
|
||||
double expectedAvgRecall = 0.0;
|
||||
|
||||
Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
|
||||
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
|
||||
}
|
||||
{
|
||||
// When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here
|
||||
List<Recall.PerClassResult> expectedPerClassResults =
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("true", 0.0),
|
||||
new Recall.PerClassResult("false", 0.0));
|
||||
double expectedAvgRecall = 0.0;
|
||||
|
||||
Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD);
|
||||
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
|
||||
}
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_CardinalityTooHigh() {
|
||||
|
@ -348,16 +469,16 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Recall()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
|
||||
}
|
||||
|
||||
private void evaluateWithMulticlassConfusionMatrix() {
|
||||
private void evaluateMulticlassConfusionMatrix() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(
|
||||
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -417,16 +538,16 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
}
|
||||
|
||||
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
|
||||
evaluateWithMulticlassConfusionMatrix();
|
||||
evaluateMulticlassConfusionMatrix();
|
||||
|
||||
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get();
|
||||
evaluateWithMulticlassConfusionMatrix();
|
||||
evaluateMulticlassConfusionMatrix();
|
||||
|
||||
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get();
|
||||
evaluateWithMulticlassConfusionMatrix();
|
||||
evaluateMulticlassConfusionMatrix();
|
||||
|
||||
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get();
|
||||
ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix);
|
||||
ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateMulticlassConfusionMatrix);
|
||||
|
||||
assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class)));
|
||||
TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause();
|
||||
|
@ -438,7 +559,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(
|
||||
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
|
||||
ANIMAL_NAME_KEYWORD_FIELD,
|
||||
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
|
||||
Arrays.asList(new MulticlassConfusionMatrix(3, null))));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -476,13 +599,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
client().admin().indices().prepareCreate(indexName)
|
||||
.addMapping("_doc",
|
||||
ANIMAL_NAME_KEYWORD_FIELD, "type=keyword",
|
||||
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
|
||||
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, "type=keyword",
|
||||
NO_LEGS_KEYWORD_FIELD, "type=keyword",
|
||||
NO_LEGS_INTEGER_FIELD, "type=integer",
|
||||
NO_LEGS_PREDICTION_FIELD, "type=integer",
|
||||
NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer",
|
||||
IS_PREDATOR_KEYWORD_FIELD, "type=keyword",
|
||||
IS_PREDATOR_BOOLEAN_FIELD, "type=boolean",
|
||||
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
|
||||
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean")
|
||||
.get();
|
||||
}
|
||||
|
||||
|
@ -497,13 +620,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
new IndexRequest(indexName)
|
||||
.source(
|
||||
ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i),
|
||||
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
|
||||
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()),
|
||||
NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1),
|
||||
NO_LEGS_INTEGER_FIELD, i + 1,
|
||||
NO_LEGS_PREDICTION_FIELD, j + 1,
|
||||
NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1,
|
||||
IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0),
|
||||
IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0,
|
||||
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
|
||||
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -519,7 +642,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
for (int i = 0; i < distinctAnimalCount; i++) {
|
||||
bulkRequestBuilder.add(
|
||||
new IndexRequest(indexName)
|
||||
.source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
|
||||
.source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, randomAlphaOfLength(5)));
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
|
|
Loading…
Reference in New Issue