From 5bb668b866feb05938fdacc4a250196510f3bac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 20 Dec 2019 14:24:23 +0100 Subject: [PATCH] [7.x] Get rid of maxClassesCardinality internal parameter (#50418) (#50423) --- .../evaluation/classification/Precision.java | 22 ++++------------- .../evaluation/classification/Recall.java | 22 ++++------------- .../ClassificationEvaluationIT.java | 24 ++++++++++++++++--- 3 files changed, 31 insertions(+), 37 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index b8b468aa037..6ef2aeb1a86 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -76,25 +75,15 @@ public class Precision implements EvaluationMetric { return PARSER.apply(parser, null); } - private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + private static final int MAX_CLASSES_CARDINALITY = 1000; - private final int maxClassesCardinality; private String actualField; private List topActualClassNames; private EvaluationMetricResult result; - public Precision() { - this((Integer) null); - } + public Precision() {} - // Visible for testing - public Precision(@Nullable Integer maxClassesCardinality) { - this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; - } - - public Precision(StreamInput in) throws IOException { - this.maxClassesCardinality = in.readVInt(); - } + public Precision(StreamInput in) throws IOException {} @Override public String getWriteableName() { @@ -116,7 +105,7 @@ public class Precision implements EvaluationMetric { AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME) .field(actualField) .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) - .size(maxClassesCardinality)), + .size(MAX_CLASSES_CARDINALITY)), Collections.emptyList()); } if (result == null) { // This is step 2 @@ -142,7 +131,7 @@ public class Precision implements EvaluationMetric { if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME); if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) { - // This means there were more than {@code maxClassesCardinality} buckets. + // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // We cannot calculate average precision accurately, so we fail. throw ExceptionsHelper.badRequestException( "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField); @@ -175,7 +164,6 @@ public class Precision implements EvaluationMetric { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(maxClassesCardinality); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index c3151b82484..27ad7a8b3bc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -70,24 +69,14 @@ public class Recall implements EvaluationMetric { return PARSER.apply(parser, null); } - private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + private static final int MAX_CLASSES_CARDINALITY = 1000; - private final int maxClassesCardinality; private String actualField; private EvaluationMetricResult result; - public Recall() { - this((Integer) null); - } + public Recall() {} - // Visible for testing - public Recall(@Nullable Integer maxClassesCardinality) { - this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; - } - - public Recall(StreamInput in) throws IOException { - this.maxClassesCardinality = in.readVInt(); - } + public Recall(StreamInput in) throws IOException {} @Override public String getWriteableName() { @@ -111,7 +100,7 @@ public class Recall implements EvaluationMetric { Arrays.asList( AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME) .field(actualField) - .size(maxClassesCardinality) + .size(MAX_CLASSES_CARDINALITY) .subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))), Arrays.asList( PipelineAggregatorBuilders.avgBucket( @@ -126,7 +115,7 @@ public class Recall implements EvaluationMetric { aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME); if (byActualClassAgg.getSumOfOtherDocCounts() > 0) { - // This means there were more than {@code maxClassesCardinality} buckets. + // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // We cannot calculate average recall accurately, so we fail. throw ExceptionsHelper.badRequestException( "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField); @@ -149,7 +138,6 @@ public class Recall implements EvaluationMetric { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(maxClassesCardinality); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index d90609c8967..8c5987675f1 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -39,6 +39,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT @Before public void setup() { + createAnimalsIndex(ANIMALS_DATA_INDEX); indexAnimalsData(ANIMALS_DATA_INDEX); } @@ -142,12 +143,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT } public void testEvaluate_Precision_CardinalityTooHigh() { + indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision(4))))); + new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); } @@ -174,11 +176,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT } public void testEvaluate_Recall_CardinalityTooHigh() { + indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall(4))))); + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); } @@ -283,7 +286,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); } - private static void indexAnimalsData(String indexName) { + private static void createAnimalsIndex(String indexName) { client().admin().indices().prepareCreate(indexName) .addMapping("_doc", ANIMAL_NAME_FIELD, "type=keyword", @@ -293,7 +296,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT IS_PREDATOR_FIELD, "type=boolean", IS_PREDATOR_PREDICTION_FIELD, "type=boolean") .get(); + } + private static void indexAnimalsData(String indexName) { List animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox"); BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -317,4 +322,17 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT fail("Failed to index data: " + bulkResponse.buildFailureMessage()); } } + + private static void indexDistinctAnimals(String indexName, int distinctAnimalCount) { + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < distinctAnimalCount; i++) { + bulkRequestBuilder.add( + new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } }