This commit is contained in:
parent
f4530580e7
commit
bd761cce1d
|
@ -114,27 +114,23 @@ public class AucRocMetric implements EvaluationMetric {
|
|||
}
|
||||
|
||||
private static final ParseField SCORE = new ParseField("score");
|
||||
private static final ParseField DOC_COUNT = new ParseField("doc_count");
|
||||
private static final ParseField CURVE = new ParseField("curve");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
|
||||
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
|
||||
|
||||
// Registers the fields PARSER reads from a result object: "score" (double) and
// "doc_count" (long) as constructor arguments, and "curve" as an optional constructor
// argument whose array elements are parsed with AucRocPoint.fromXContent.
static {
|
||||
PARSER.declareDouble(constructorArg(), SCORE);
|
||||
PARSER.declareLong(constructorArg(), DOC_COUNT);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
|
||||
}
|
||||
|
||||
private final double score;
|
||||
private final long docCount;
|
||||
private final List<AucRocPoint> curve;
|
||||
|
||||
public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
|
||||
public Result(double score, @Nullable List<AucRocPoint> curve) {
|
||||
this.score = score;
|
||||
this.docCount = docCount;
|
||||
this.curve = curve;
|
||||
}
|
||||
|
||||
|
@ -147,10 +143,6 @@ public class AucRocMetric implements EvaluationMetric {
|
|||
return score;
|
||||
}
|
||||
|
||||
/** Returns the {@code docCount} value stored in this result. */
public long getDocCount() {
|
||||
return docCount;
|
||||
}
|
||||
|
||||
/**
 * Returns the curve points wrapped in an unmodifiable list, or {@code null} when the
 * {@code curve} field itself is {@code null}.
 */
public List<AucRocPoint> getCurve() {
|
||||
return curve == null ? null : Collections.unmodifiableList(curve);
|
||||
}
|
||||
|
@ -159,7 +151,6 @@ public class AucRocMetric implements EvaluationMetric {
|
|||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(SCORE.getPreferredName(), score);
|
||||
builder.field(DOC_COUNT.getPreferredName(), docCount);
|
||||
if (curve != null && curve.isEmpty() == false) {
|
||||
builder.field(CURVE.getPreferredName(), curve);
|
||||
}
|
||||
|
@ -173,13 +164,12 @@ public class AucRocMetric implements EvaluationMetric {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return score == that.score
|
||||
&& docCount == that.docCount
|
||||
&& Objects.equals(curve, that.curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(score, docCount, curve);
|
||||
return Objects.hash(score, curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -200,6 +200,7 @@ import java.util.Locale;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
|
@ -1931,18 +1932,17 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
createIndex(indexName, mappingForClassification());
|
||||
BulkRequest regressionBulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForClassification(indexName, "cat", "cat", 0.9))
|
||||
.add(docForClassification(indexName, "cat", "cat", 0.85))
|
||||
.add(docForClassification(indexName, "cat", "cat", 0.95))
|
||||
.add(docForClassification(indexName, "cat", "dog", 0.4))
|
||||
.add(docForClassification(indexName, "cat", "fish", 0.35))
|
||||
.add(docForClassification(indexName, "dog", "cat", 0.5))
|
||||
.add(docForClassification(indexName, "dog", "dog", 0.4))
|
||||
.add(docForClassification(indexName, "dog", "dog", 0.35))
|
||||
.add(docForClassification(indexName, "dog", "dog", 0.6))
|
||||
.add(docForClassification(indexName, "ant", "cat", 0.1));
|
||||
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
|
||||
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
|
||||
.add(docForClassification(indexName, "cat", "cat", "horse", "dog"))
|
||||
.add(docForClassification(indexName, "cat", "dog", "cat", "mule"))
|
||||
.add(docForClassification(indexName, "cat", "fish", "cat", "dog"))
|
||||
.add(docForClassification(indexName, "dog", "cat", "dog", "mule"))
|
||||
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
|
||||
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
|
||||
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
|
||||
.add(docForClassification(indexName, "ant", "cat", "ant", "wasp"));
|
||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
|
||||
{ // AucRoc
|
||||
|
@ -1957,8 +1957,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
|
||||
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
|
||||
assertThat(aucRocResult.getDocCount(), equalTo(5L));
|
||||
assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9));
|
||||
assertNotNull(aucRocResult.getCurve());
|
||||
}
|
||||
{ // Accuracy
|
||||
|
@ -2173,21 +2172,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
|
||||
private static IndexRequest docForClassification(String indexName,
|
||||
String actualClass,
|
||||
String... topPredictedClasses) {
|
||||
assert topPredictedClasses.length > 0;
|
||||
return new IndexRequest()
|
||||
.index(indexName)
|
||||
.source(XContentType.JSON,
|
||||
actualClassField, actualClass,
|
||||
predictedClassField, predictedClass,
|
||||
topClassesField, Arrays.asList(
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", predictedClass);
|
||||
put("class_probability", p);
|
||||
}},
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", "other");
|
||||
put("class_probability", 1 - p);
|
||||
}}));
|
||||
predictedClassField, topPredictedClasses[0],
|
||||
topClassesField, IntStream.range(0, topPredictedClasses.length)
|
||||
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
|
||||
.mapToObj(i -> new HashMap<String, Object>() {{
|
||||
put("class_name", topPredictedClasses[i]);
|
||||
put("class_probability", 1.0 / (2 << i));
|
||||
}})
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
private static final String actualRegression = "regression_actual";
|
||||
|
|
|
@ -201,7 +201,6 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition;
|
|||
import org.elasticsearch.client.ml.job.results.Influencer;
|
||||
import org.elasticsearch.client.ml.job.results.OverallBucket;
|
||||
import org.elasticsearch.client.ml.job.stats.JobStats;
|
||||
import org.elasticsearch.common.TriFunction;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
|
@ -229,8 +228,11 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
|
@ -3463,34 +3465,33 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.endObject()
|
||||
.endObject()
|
||||
.endObject());
|
||||
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
|
||||
BiFunction<String, String[], IndexRequest> indexRequest = (actualClass, topPredictedClasses) -> {
|
||||
assert topPredictedClasses.length > 0;
|
||||
return new IndexRequest()
|
||||
.source(XContentType.JSON,
|
||||
"actual_class", actualClass,
|
||||
"predicted_class", predictedClass,
|
||||
"ml.top_classes", Arrays.asList(
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", predictedClass);
|
||||
put("class_probability", p);
|
||||
}},
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", "other");
|
||||
put("class_probability", 1 - p);
|
||||
}}));
|
||||
"predicted_class", topPredictedClasses[0],
|
||||
"ml.top_classes", IntStream.range(0, topPredictedClasses.length)
|
||||
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
|
||||
.mapToObj(i -> new HashMap<String, Object>() {{
|
||||
put("class_name", topPredictedClasses[i]);
|
||||
put("class_probability", 1.0 / (2 << i));
|
||||
}})
|
||||
.collect(toList()));
|
||||
};
|
||||
BulkRequest bulkRequest =
|
||||
new BulkRequest(indexName)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(indexRequest.apply("cat", "cat", 0.9)) // #0
|
||||
.add(indexRequest.apply("cat", "cat", 0.9)) // #1
|
||||
.add(indexRequest.apply("cat", "cat", 0.9)) // #2
|
||||
.add(indexRequest.apply("cat", "dog", 0.9)) // #3
|
||||
.add(indexRequest.apply("cat", "fox", 0.9)) // #4
|
||||
.add(indexRequest.apply("dog", "cat", 0.9)) // #5
|
||||
.add(indexRequest.apply("dog", "dog", 0.9)) // #6
|
||||
.add(indexRequest.apply("dog", "dog", 0.9)) // #7
|
||||
.add(indexRequest.apply("dog", "dog", 0.9)) // #8
|
||||
.add(indexRequest.apply("ant", "cat", 0.9)); // #9
|
||||
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0
|
||||
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1
|
||||
.add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2
|
||||
.add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3
|
||||
.add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4
|
||||
.add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5
|
||||
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6
|
||||
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7
|
||||
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8
|
||||
.add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9
|
||||
RestHighLevelClient client = highLevelClient();
|
||||
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
||||
|
@ -3530,7 +3531,6 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
|
||||
double aucRocScore = aucRocResult.getScore(); // <11>
|
||||
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
|
||||
// end::evaluate-data-frame-results-classification
|
||||
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||
|
@ -3565,8 +3565,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(otherClassesCount, equalTo(0L));
|
||||
|
||||
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||
assertThat(aucRocScore, equalTo(0.2625));
|
||||
assertThat(aucRocDocCount, equalTo(5L));
|
||||
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,6 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
|
|||
public static AucRocMetric.Result randomResult() {
|
||||
return new AucRocMetric.Result(
|
||||
randomDouble(),
|
||||
randomLong(),
|
||||
Stream
|
||||
.generate(AucRocMetricAucRocPointTests::randomPoint)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
|
|
|
@ -121,7 +121,6 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
|
|||
<9> Fetching the number of classes that were not included in the matrix
|
||||
<10> Fetching AucRoc metric by name
|
||||
<11> Fetching the actual AucRoc score
|
||||
<12> Fetching the number of documents that were used in order to calculate AucRoc score
|
||||
|
||||
===== Regression
|
||||
|
||||
|
|
|
@ -193,10 +193,8 @@ belongs.
|
|||
`class_name`::::
|
||||
(Required, string) Name of the only class that will be treated as
|
||||
positive during AUC ROC calculation. Other classes will be treated as
|
||||
negative ("one-vs-all" strategy). Documents which do not have `class_name`
|
||||
in the list of their top classes will not be taken into account for
|
||||
evaluation. The number of documents taken into account is returned in the
|
||||
evaluation result (`auc_roc.doc_count` field).
|
||||
negative ("one-vs-all" strategy). All the evaluated documents must have `class_name`
|
||||
in the list of their top classes.
|
||||
|
||||
`include_curve`::::
|
||||
(Optional, boolean) Whether or not the curve should be returned in
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -231,26 +230,18 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
|
|||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final String SCORE = "score";
|
||||
private static final String DOC_COUNT = "doc_count";
|
||||
private static final String CURVE = "curve";
|
||||
|
||||
private final double score;
|
||||
private final Long docCount;
|
||||
private final List<AucRocPoint> curve;
|
||||
|
||||
public Result(double score, Long docCount, List<AucRocPoint> curve) {
|
||||
public Result(double score, List<AucRocPoint> curve) {
|
||||
this.score = score;
|
||||
this.docCount = docCount;
|
||||
this.curve = Objects.requireNonNull(curve);
|
||||
}
|
||||
|
||||
/**
 * Reads a result from the given stream: the score first, then (depending on the stream
 * version) the optional doc count, then the list of curve points.
 *
 * @param in stream to read from
 * @throws IOException declared by the signature; propagated from the stream reads
 */
public Result(StreamInput in) throws IOException {
|
||||
this.score = in.readDouble();
|
||||
// The doc count is only read for stream versions on or after 7.10.0; for earlier
// versions nothing is read and the field is left null.
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.docCount = in.readOptionalLong();
|
||||
} else {
|
||||
this.docCount = null;
|
||||
}
|
||||
this.curve = in.readList(AucRocPoint::new);
|
||||
}
|
||||
|
||||
|
@ -258,10 +249,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
|
|||
return score;
|
||||
}
|
||||
|
||||
/** Returns the doc count as a boxed {@code Long}; the field is nullable, so this may be {@code null}. */
public Long getDocCount() {
|
||||
return docCount;
|
||||
}
|
||||
|
||||
/** Returns the curve points wrapped in an unmodifiable list. */
public List<AucRocPoint> getCurve() {
|
||||
return Collections.unmodifiableList(curve);
|
||||
}
|
||||
|
@ -279,9 +266,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
|
|||
/**
 * Writes this result to the given stream: the score first, then (depending on the stream
 * version) the optional doc count, then the list of curve points.
 *
 * @param out stream to write to
 * @throws IOException declared by the signature; propagated from the stream writes
 */
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(score);
|
||||
// The doc count is only written for stream versions on or after 7.10.0.
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeOptionalLong(docCount);
|
||||
}
|
||||
out.writeList(curve);
|
||||
}
|
||||
|
||||
|
@ -289,9 +273,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
|
|||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(SCORE, score);
|
||||
if (docCount != null) {
|
||||
builder.field(DOC_COUNT, docCount);
|
||||
}
|
||||
if (curve.isEmpty() == false) {
|
||||
builder.field(CURVE, curve);
|
||||
}
|
||||
|
@ -305,13 +286,12 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return score == that.score
|
||||
&& Objects.equals(docCount, that.docCount)
|
||||
&& Objects.equals(curve, that.curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(score, docCount, curve);
|
||||
return Objects.hash(score, curve);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -183,42 +183,39 @@ public class AucRoc extends AbstractAucRoc {
|
|||
Filter classAgg = aggs.get(TRUE_AGG_NAME);
|
||||
Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME);
|
||||
Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
|
||||
|
||||
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
|
||||
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
|
||||
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
|
||||
|
||||
if (classAgg.getDocCount() == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have the value [{}]",
|
||||
getName(), fields.get().getActualField(), className);
|
||||
}
|
||||
if (classNestedFilter.getDocCount() == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have the value [{}]",
|
||||
getName(), fields.get().getPredictedClassField(), className);
|
||||
}
|
||||
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
|
||||
double[] tpPercentiles = percentilesArray(classPercentiles);
|
||||
|
||||
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
|
||||
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
|
||||
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
|
||||
if (restAgg.getDocCount() == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have a different value than [{}]",
|
||||
getName(), fields.get().getActualField(), className);
|
||||
}
|
||||
if (restNestedFilter.getDocCount() == 0) {
|
||||
long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount();
|
||||
long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount();
|
||||
if (filteredDocCount < totalDocCount) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have the value [{}]",
|
||||
getName(), fields.get().getPredictedClassField(), className);
|
||||
"[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). "
|
||||
+ "This is probably caused by the {} value being less than the total number of actual classes in the dataset.",
|
||||
getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount,
|
||||
org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName());
|
||||
}
|
||||
|
||||
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
|
||||
double[] tpPercentiles = percentilesArray(classPercentiles);
|
||||
Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
|
||||
double[] fpPercentiles = percentilesArray(restPercentiles);
|
||||
|
||||
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
|
||||
double aucRocScore = calculateAucScore(aucRocCurve);
|
||||
result.set(
|
||||
new Result(
|
||||
aucRocScore,
|
||||
classNestedFilter.getDocCount() + restNestedFilter.getDocCount(),
|
||||
includeCurve ? aucRocCurve : Collections.emptyList()));
|
||||
result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -174,11 +174,7 @@ public class AucRoc extends AbstractAucRoc {
|
|||
|
||||
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
|
||||
double aucRocScore = calculateAucScore(aucRocCurve);
|
||||
result.set(
|
||||
new Result(
|
||||
aucRocScore,
|
||||
classAgg.getDocCount() + restAgg.getDocCount(),
|
||||
includeCurve ? aucRocCurve : Collections.emptyList()));
|
||||
result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -20,13 +20,12 @@ public class AucRocResultTests extends AbstractWireSerializingTestCase<Result> {
|
|||
|
||||
public static Result createRandom() {
|
||||
double score = randomDoubleBetween(0.0, 1.0, true);
|
||||
Long docCount = randomBoolean() ? randomLong() : null;
|
||||
List<AucRocPoint> curve =
|
||||
Stream
|
||||
.generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble()))
|
||||
.limit(randomIntBetween(0, 20))
|
||||
.collect(Collectors.toList());
|
||||
return new Result(score, docCount, curve);
|
||||
return new Result(score, curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -167,14 +167,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
// Calls evaluateAucRoc(false) and checks the result: score within 0.0001 of 0.5,
// doc count of 75, and an empty curve.
// NOTE(review): the boolean presumably toggles curve inclusion (per the test name) —
// confirm in evaluateAucRoc.
public void testEvaluate_AucRoc_DoNotIncludeCurve() {
|
||||
AucRoc.Result aucrocResult = evaluateAucRoc(false);
|
||||
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
|
||||
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
|
||||
assertThat(aucrocResult.getCurve(), hasSize(0));
|
||||
}
|
||||
|
||||
// Calls evaluateAucRoc(true) and checks the result: score within 0.0001 of 0.5,
// doc count of 75, and a curve with at least one point.
// NOTE(review): the boolean presumably toggles curve inclusion (per the test name) —
// confirm in evaluateAucRoc.
public void testEvaluate_AucRoc_IncludeCurve() {
|
||||
AucRoc.Result aucrocResult = evaluateAucRoc(true);
|
||||
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
|
||||
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
|
||||
assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
|
||||
}
|
||||
|
||||
|
|
|
@ -981,7 +981,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1);
|
||||
assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
|
||||
assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
|
||||
assertThat(aucRocResult.getDocCount(), allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo(350L)));
|
||||
assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0)));
|
||||
}
|
||||
|
||||
|
|
|
@ -207,7 +207,6 @@ setup:
|
|||
}
|
||||
}
|
||||
- match: { outlier_detection.auc_roc.score: 0.9899 }
|
||||
- match: { outlier_detection.auc_roc.doc_count: 8 }
|
||||
- is_false: outlier_detection.auc_roc.curve
|
||||
|
||||
---
|
||||
|
@ -228,7 +227,6 @@ setup:
|
|||
}
|
||||
}
|
||||
- match: { outlier_detection.auc_roc.score: 0.9899 }
|
||||
- match: { outlier_detection.auc_roc.doc_count: 8 }
|
||||
- is_false: outlier_detection.auc_roc.curve
|
||||
|
||||
---
|
||||
|
@ -249,7 +247,6 @@ setup:
|
|||
}
|
||||
}
|
||||
- match: { outlier_detection.auc_roc.score: 0.9899 }
|
||||
- match: { outlier_detection.auc_roc.doc_count: 8 }
|
||||
- is_true: outlier_detection.auc_roc.curve
|
||||
|
||||
---
|
||||
|
@ -415,7 +412,6 @@ setup:
|
|||
}
|
||||
}
|
||||
- is_true: outlier_detection.auc_roc.score
|
||||
- is_true: outlier_detection.auc_roc.doc_count
|
||||
- is_true: outlier_detection.precision.0\.25
|
||||
- is_true: outlier_detection.precision.0\.5
|
||||
- is_true: outlier_detection.precision.0\.75
|
||||
|
@ -689,7 +685,7 @@ setup:
|
|||
---
|
||||
"Test classification auc_roc given predicted_class_field is never equal to mouse":
|
||||
- do:
|
||||
catch: /\[auc_roc\] requires at least one \[ml.top_classes.class_name\] to have the value \[mouse\]/
|
||||
catch: /\[auc_roc\] requires that \[mouse\] appears as one of the \[ml.top_classes.class_name\] for every document \(appeared in 0 out of 8\)./
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
|
@ -726,7 +722,6 @@ setup:
|
|||
}
|
||||
}
|
||||
- match: { classification.auc_roc.score: 0.8050111095212122 }
|
||||
- match: { classification.auc_roc.doc_count: 8 }
|
||||
- is_false: classification.auc_roc.curve
|
||||
---
|
||||
"Test classification auc_roc with default top_classes_field":
|
||||
|
@ -747,7 +742,6 @@ setup:
|
|||
}
|
||||
}
|
||||
- match: { classification.auc_roc.score: 0.8050111095212122 }
|
||||
- match: { classification.auc_roc.doc_count: 8 }
|
||||
- is_false: classification.auc_roc.curve
|
||||
---
|
||||
"Test classification accuracy with missing predicted_field":
|
||||
|
|
Loading…
Reference in New Issue