This commit is contained in:
parent
541dc262bb
commit
5bb668b866
|
@ -5,7 +5,6 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -76,25 +75,15 @@ public class Precision implements EvaluationMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
|
||||
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||
|
||||
private final int maxClassesCardinality;
|
||||
private String actualField;
|
||||
private List<String> topActualClassNames;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
public Precision() {
|
||||
this((Integer) null);
|
||||
}
|
||||
public Precision() {}
|
||||
|
||||
// Visible for testing
|
||||
public Precision(@Nullable Integer maxClassesCardinality) {
|
||||
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
|
||||
}
|
||||
|
||||
public Precision(StreamInput in) throws IOException {
|
||||
this.maxClassesCardinality = in.readVInt();
|
||||
}
|
||||
public Precision(StreamInput in) throws IOException {}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
|
@ -116,7 +105,7 @@ public class Precision implements EvaluationMetric {
|
|||
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(maxClassesCardinality)),
|
||||
.size(MAX_CLASSES_CARDINALITY)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
if (result == null) { // This is step 2
|
||||
|
@ -142,7 +131,7 @@ public class Precision implements EvaluationMetric {
|
|||
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
|
||||
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
|
||||
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
|
||||
// This means there were more than {@code maxClassesCardinality} buckets.
|
||||
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
||||
// We cannot calculate average precision accurately, so we fail.
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
|
||||
|
@ -175,7 +164,6 @@ public class Precision implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(maxClassesCardinality);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -70,24 +69,14 @@ public class Recall implements EvaluationMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
|
||||
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||
|
||||
private final int maxClassesCardinality;
|
||||
private String actualField;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
public Recall() {
|
||||
this((Integer) null);
|
||||
}
|
||||
public Recall() {}
|
||||
|
||||
// Visible for testing
|
||||
public Recall(@Nullable Integer maxClassesCardinality) {
|
||||
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
|
||||
}
|
||||
|
||||
public Recall(StreamInput in) throws IOException {
|
||||
this.maxClassesCardinality = in.readVInt();
|
||||
}
|
||||
public Recall(StreamInput in) throws IOException {}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
|
@ -111,7 +100,7 @@ public class Recall implements EvaluationMetric {
|
|||
Arrays.asList(
|
||||
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
|
||||
.field(actualField)
|
||||
.size(maxClassesCardinality)
|
||||
.size(MAX_CLASSES_CARDINALITY)
|
||||
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
|
||||
Arrays.asList(
|
||||
PipelineAggregatorBuilders.avgBucket(
|
||||
|
@ -126,7 +115,7 @@ public class Recall implements EvaluationMetric {
|
|||
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
|
||||
if (byActualClassAgg.getSumOfOtherDocCounts() > 0) {
|
||||
// This means there were more than {@code maxClassesCardinality} buckets.
|
||||
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
||||
// We cannot calculate average recall accurately, so we fail.
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
|
||||
|
@ -149,7 +138,6 @@ public class Recall implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(maxClassesCardinality);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -39,6 +39,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
|
||||
// Test fixture: creates the animals index first, then indexes the animals data into it.
@Before
|
||||
public void setup() {
|
||||
createAnimalsIndex(ANIMALS_DATA_INDEX);
|
||||
indexAnimalsData(ANIMALS_DATA_INDEX);
|
||||
}
|
||||
|
||||
|
@ -142,12 +143,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
}
|
||||
|
||||
public void testEvaluate_Precision_CardinalityTooHigh() {
|
||||
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision(4)))));
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||
}
|
||||
|
||||
|
@ -174,11 +176,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
}
|
||||
|
||||
public void testEvaluate_Recall_CardinalityTooHigh() {
|
||||
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall(4)))));
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||
}
|
||||
|
||||
|
@ -283,7 +286,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
|
||||
}
|
||||
|
||||
private static void indexAnimalsData(String indexName) {
|
||||
private static void createAnimalsIndex(String indexName) {
|
||||
client().admin().indices().prepareCreate(indexName)
|
||||
.addMapping("_doc",
|
||||
ANIMAL_NAME_FIELD, "type=keyword",
|
||||
|
@ -293,7 +296,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
IS_PREDATOR_FIELD, "type=boolean",
|
||||
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
|
||||
.get();
|
||||
}
|
||||
|
||||
private static void indexAnimalsData(String indexName) {
|
||||
List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
@ -317,4 +322,17 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Bulk-indexes {@code distinctAnimalCount} documents into {@code indexName}, each with a
 * different value ("animal_0", "animal_1", ...) in {@code ANIMAL_NAME_FIELD}.
 * The cardinality-too-high tests call this with 1001 to produce 1001 distinct actual class names.
 * Fails the calling test if the bulk response reports any failures.
 *
 * @param indexName name of the index to write the documents to
 * @param distinctAnimalCount number of documents (and therefore distinct animal names) to index
 */
private static void indexDistinctAnimals(String indexName, int distinctAnimalCount) {
|
||||
// The bulk request is built with RefreshPolicy.IMMEDIATE.
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < distinctAnimalCount; i++) {
|
||||
// Actual name is unique per iteration; the prediction field gets the result of randomAlphaOfLength(5).
bulkRequestBuilder.add(
|
||||
new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
// Surface indexing failures as a test failure instead of continuing with partial data.
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue