[7.x] Get rid of maxClassesCardinality internal parameter (#50418) (#50423)

This commit is contained in:
Przemysław Witek 2019-12-20 14:24:23 +01:00 committed by GitHub
parent 541dc262bb
commit 5bb668b866
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 37 deletions

View File

@@ -5,7 +5,6 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@@ -76,25 +75,15 @@ public class Precision implements EvaluationMetric {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; private static final int MAX_CLASSES_CARDINALITY = 1000;
private final int maxClassesCardinality;
private String actualField; private String actualField;
private List<String> topActualClassNames; private List<String> topActualClassNames;
private EvaluationMetricResult result; private EvaluationMetricResult result;
public Precision() { public Precision() {}
this((Integer) null);
}
// Visible for testing public Precision(StreamInput in) throws IOException {}
public Precision(@Nullable Integer maxClassesCardinality) {
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Precision(StreamInput in) throws IOException {
this.maxClassesCardinality = in.readVInt();
}
@Override @Override
public String getWriteableName() { public String getWriteableName() {
@@ -116,7 +105,7 @@ public class Precision implements EvaluationMetric {
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME) AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
.field(actualField) .field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(maxClassesCardinality)), .size(MAX_CLASSES_CARDINALITY)),
Collections.emptyList()); Collections.emptyList());
} }
if (result == null) { // This is step 2 if (result == null) { // This is step 2
@@ -142,7 +131,7 @@ public class Precision implements EvaluationMetric {
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME); Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) { if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
// This means there were more than {@code maxClassesCardinality} buckets. // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
// We cannot calculate average precision accurately, so we fail. // We cannot calculate average precision accurately, so we fail.
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField); "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
@@ -175,7 +164,6 @@ public class Precision implements EvaluationMetric {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(maxClassesCardinality);
} }
@Override @Override

View File

@@ -5,7 +5,6 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@@ -70,24 +69,14 @@ public class Recall implements EvaluationMetric {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; private static final int MAX_CLASSES_CARDINALITY = 1000;
private final int maxClassesCardinality;
private String actualField; private String actualField;
private EvaluationMetricResult result; private EvaluationMetricResult result;
public Recall() { public Recall() {}
this((Integer) null);
}
// Visible for testing public Recall(StreamInput in) throws IOException {}
public Recall(@Nullable Integer maxClassesCardinality) {
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Recall(StreamInput in) throws IOException {
this.maxClassesCardinality = in.readVInt();
}
@Override @Override
public String getWriteableName() { public String getWriteableName() {
@@ -111,7 +100,7 @@ public class Recall implements EvaluationMetric {
Arrays.asList( Arrays.asList(
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME) AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
.field(actualField) .field(actualField)
.size(maxClassesCardinality) .size(MAX_CLASSES_CARDINALITY)
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))), .subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
Arrays.asList( Arrays.asList(
PipelineAggregatorBuilders.avgBucket( PipelineAggregatorBuilders.avgBucket(
@@ -126,7 +115,7 @@ public class Recall implements EvaluationMetric {
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME); Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
if (byActualClassAgg.getSumOfOtherDocCounts() > 0) { if (byActualClassAgg.getSumOfOtherDocCounts() > 0) {
// This means there were more than {@code maxClassesCardinality} buckets. // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
// We cannot calculate average recall accurately, so we fail. // We cannot calculate average recall accurately, so we fail.
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField); "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
@@ -149,7 +138,6 @@ public class Recall implements EvaluationMetric {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(maxClassesCardinality);
} }
@Override @Override

View File

@@ -39,6 +39,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
@Before @Before
public void setup() { public void setup() {
createAnimalsIndex(ANIMALS_DATA_INDEX);
indexAnimalsData(ANIMALS_DATA_INDEX); indexAnimalsData(ANIMALS_DATA_INDEX);
} }
@@ -142,12 +143,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
} }
public void testEvaluate_Precision_CardinalityTooHigh() { public void testEvaluate_Precision_CardinalityTooHigh() {
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
ElasticsearchStatusException e = ElasticsearchStatusException e =
expectThrows( expectThrows(
ElasticsearchStatusException.class, ElasticsearchStatusException.class,
() -> evaluateDataFrame( () -> evaluateDataFrame(
ANIMALS_DATA_INDEX, ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision(4))))); new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
} }
@@ -174,11 +176,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
} }
public void testEvaluate_Recall_CardinalityTooHigh() { public void testEvaluate_Recall_CardinalityTooHigh() {
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
ElasticsearchStatusException e = ElasticsearchStatusException e =
expectThrows( expectThrows(
ElasticsearchStatusException.class, ElasticsearchStatusException.class,
() -> evaluateDataFrame( () -> evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall(4))))); ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
} }
@@ -283,7 +286,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
} }
private static void indexAnimalsData(String indexName) { private static void createAnimalsIndex(String indexName) {
client().admin().indices().prepareCreate(indexName) client().admin().indices().prepareCreate(indexName)
.addMapping("_doc", .addMapping("_doc",
ANIMAL_NAME_FIELD, "type=keyword", ANIMAL_NAME_FIELD, "type=keyword",
@@ -293,7 +296,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
IS_PREDATOR_FIELD, "type=boolean", IS_PREDATOR_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_FIELD, "type=boolean") IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
.get(); .get();
}
private static void indexAnimalsData(String indexName) {
List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox"); List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
@@ -317,4 +322,17 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
fail("Failed to index data: " + bulkResponse.buildFailureMessage()); fail("Failed to index data: " + bulkResponse.buildFailureMessage());
} }
} }
/**
 * Bulk-indexes {@code distinctAnimalCount} documents into {@code indexName}.
 * Document number {@code i} stores {@code "animal_" + i} under {@code ANIMAL_NAME_FIELD} and the result of
 * {@code randomAlphaOfLength(5)} under {@code ANIMAL_NAME_PREDICTION_FIELD}, so every document carries a
 * different value in {@code ANIMAL_NAME_FIELD}.
 * The bulk request is sent with {@code RefreshPolicy.IMMEDIATE}; if the response reports any failure,
 * {@code fail} is invoked with the response's failure message.
 *
 * @param indexName           name of the index the documents are written to
 * @param distinctAnimalCount number of documents (and distinct animal names) to index
 */
private static void indexDistinctAnimals(String indexName, int distinctAnimalCount) {
    BulkRequestBuilder bulk = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
    for (int animalId = 0; animalId < distinctAnimalCount; animalId++) {
        String actualName = "animal_" + animalId;
        String predictedName = randomAlphaOfLength(5);
        IndexRequest document =
            new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, actualName, ANIMAL_NAME_PREDICTION_FIELD, predictedName);
        bulk.add(document);
    }
    BulkResponse response = bulk.get();
    if (response.hasFailures() == false) {
        return;
    }
    fail("Failed to index data: " + response.buildFailureMessage());
}
} }