[7.x] Allow the user to specify 'query' in Evaluate Data Frame request (#45775) (#45825)

This commit is contained in:
Przemysław Witek 2019-08-22 11:14:26 +02:00 committed by GitHub
parent b95ca9c3bb
commit 7512337922
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 412 additions and 108 deletions

View File

@ -21,7 +21,9 @@ package org.elasticsearch.client.ml;
import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
@ -37,20 +39,25 @@ import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
private static final ParseField INDEX = new ParseField("index");
private static final ParseField QUERY = new ParseField("query");
private static final ParseField EVALUATION = new ParseField("evaluation");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<EvaluateDataFrameRequest, Void> PARSER =
new ConstructingObjectParser<>(
"evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List<String>) args[0], (Evaluation) args[1]));
"evaluate_data_frame_request",
true,
args -> new EvaluateDataFrameRequest((List<String>) args[0], (QueryConfig) args[1], (Evaluation) args[2]));
static {
PARSER.declareStringArray(constructorArg(), INDEX);
PARSER.declareObject(optionalConstructorArg(), (p, c) -> QueryConfig.fromXContent(p), QUERY);
PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
}
@ -67,14 +74,16 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
}
private List<String> indices;
private QueryConfig queryConfig;
private Evaluation evaluation;
public EvaluateDataFrameRequest(String index, Evaluation evaluation) {
this(Arrays.asList(index), evaluation);
public EvaluateDataFrameRequest(String index, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
this(Arrays.asList(index), queryConfig, evaluation);
}
public EvaluateDataFrameRequest(List<String> indices, Evaluation evaluation) {
public EvaluateDataFrameRequest(List<String> indices, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
setIndices(indices);
setQueryConfig(queryConfig);
setEvaluation(evaluation);
}
@ -87,6 +96,14 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
this.indices = new ArrayList<>(indices);
}
/**
 * Returns the query configuration held by this request.
 * May be {@code null}: the constructors accept a {@code @Nullable} value and store it unchanged.
 */
public QueryConfig getQueryConfig() {
return queryConfig;
}
/**
 * Sets the query configuration for this request.
 * The value is stored as-is (no copy, no validation); the constructors route their
 * {@code @Nullable} argument through this setter, so {@code null} is an accepted value.
 */
public final void setQueryConfig(QueryConfig queryConfig) {
this.queryConfig = queryConfig;
}
public Evaluation getEvaluation() {
return evaluation;
}
@ -111,18 +128,22 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder
.startObject()
.array(INDEX.getPreferredName(), indices.toArray())
.startObject(EVALUATION.getPreferredName())
.field(evaluation.getName(), evaluation)
.endObject()
builder.startObject();
builder.array(INDEX.getPreferredName(), indices.toArray());
if (queryConfig != null) {
builder.field(QUERY.getPreferredName(), queryConfig.getQuery());
}
builder
.startObject(EVALUATION.getPreferredName())
.field(evaluation.getName(), evaluation)
.endObject();
builder.endObject();
return builder;
}
@Override
public int hashCode() {
return Objects.hash(indices, evaluation);
return Objects.hash(indices, queryConfig, evaluation);
}
@Override
@ -131,6 +152,7 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
if (o == null || getClass() != o.getClass()) return false;
EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o;
return Objects.equals(indices, that.indices)
&& Objects.equals(queryConfig, that.queryConfig)
&& Objects.equals(evaluation, that.evaluation);
}
}

View File

@ -36,6 +36,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests;
import org.elasticsearch.client.ml.FindFileStructureRequest;
import org.elasticsearch.client.ml.FindFileStructureRequestTests;
import org.elasticsearch.client.ml.FlushJobRequest;
@ -85,9 +86,6 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.Detector;
@ -779,13 +777,7 @@ public class MLRequestConvertersTests extends ESTestCase {
}
public void testEvaluateDataFrame() throws IOException {
EvaluateDataFrameRequest evaluateRequest =
new EvaluateDataFrameRequest(
Arrays.asList(generateRandomStringArray(1, 10, false, false)),
new BinarySoftClassification(
randomAlphaOfLengthBetween(1, 10),
randomAlphaOfLengthBetween(1, 10),
PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7)));
EvaluateDataFrameRequest evaluateRequest = EvaluateDataFrameRequestTests.createRandom();
Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest);
assertEquals(HttpPost.METHOD_NAME, request.getMethod());
assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint());

View File

@ -149,6 +149,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.junit.After;
@ -1455,7 +1456,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
public void testStopDataFrameAnalyticsConfig() throws Exception {
String sourceIndex = "stop-test-source-index";
String destIndex = "stop-test-dest-index";
createIndex(sourceIndex, mappingForClassification());
createIndex(sourceIndex, defaultMappingForTest());
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
@ -1553,27 +1554,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(exception.status().getStatus(), equalTo(404));
}
public void testEvaluateDataFrame() throws IOException {
public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
String indexName = "evaluate-test-index";
createIndex(indexName, mappingForClassification());
BulkRequest bulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, false, 0.1)) // #0
.add(docForClassification(indexName, false, 0.2)) // #1
.add(docForClassification(indexName, false, 0.3)) // #2
.add(docForClassification(indexName, false, 0.4)) // #3
.add(docForClassification(indexName, false, 0.7)) // #4
.add(docForClassification(indexName, true, 0.2)) // #5
.add(docForClassification(indexName, true, 0.3)) // #6
.add(docForClassification(indexName, true, 0.4)) // #7
.add(docForClassification(indexName, true, 0.8)) // #8
.add(docForClassification(indexName, true, 0.9)); // #9
.add(docForClassification(indexName, "blue", false, 0.1)) // #0
.add(docForClassification(indexName, "blue", false, 0.2)) // #1
.add(docForClassification(indexName, "blue", false, 0.3)) // #2
.add(docForClassification(indexName, "blue", false, 0.4)) // #3
.add(docForClassification(indexName, "blue", false, 0.7)) // #4
.add(docForClassification(indexName, "blue", true, 0.2)) // #5
.add(docForClassification(indexName, "green", true, 0.3)) // #6
.add(docForClassification(indexName, "green", true, 0.4)) // #7
.add(docForClassification(indexName, "green", true, 0.8)) // #8
.add(docForClassification(indexName, "green", true, 0.9)); // #9
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
indexName,
null,
new BinarySoftClassification(
actualField,
probabilityField,
@ -1624,7 +1626,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
}
public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
    final String indexName = "evaluate-with-query-test-index";
    createIndex(indexName, mappingForClassification());

    // Documents #0..#9 described as parallel arrays; every document carries the actual label "true".
    final String[] datasets = {"blue", "blue", "blue", "blue", "blue", "blue", "green", "green", "green", "green"};
    final double[] probabilities = {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0};
    BulkRequest bulk = new BulkRequest().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
    for (int docId = 0; docId < datasets.length; docId++) {
        bulk.add(docForClassification(indexName, datasets[docId], true, probabilities[docId]));
    }
    highLevelClient().bulk(bulk, RequestOptions.DEFAULT);

    MachineLearningClient machineLearningClient = highLevelClient().machineLearning();

    // Request only "blue" subset to be evaluated
    QueryConfig blueSubsetQuery = new QueryConfig(QueryBuilders.termQuery(datasetField, "blue"));
    BinarySoftClassification evaluation =
        new BinarySoftClassification(actualField, probabilityField, ConfusionMatrixMetric.at(0.5));
    EvaluateDataFrameRequest evaluateDataFrameRequest = new EvaluateDataFrameRequest(indexName, blueSubsetQuery, evaluation);

    EvaluateDataFrameResponse evaluateDataFrameResponse =
        execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
    assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME));
    assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));

    ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME);
    assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME));

    // Only the six "blue" documents are counted, all of them with actual label "true".
    ConfusionMatrixMetric.ConfusionMatrix matrixAtHalf = confusionMatrixResult.getScoreByThreshold("0.5");
    assertThat(matrixAtHalf.getTruePositives(), equalTo(4L)); // docs #0, #1, #2 and #3
    assertThat(matrixAtHalf.getFalsePositives(), equalTo(0L));
    assertThat(matrixAtHalf.getTrueNegatives(), equalTo(0L));
    assertThat(matrixAtHalf.getFalseNegatives(), equalTo(2L)); // docs #4 and #5
}
public void testEvaluateDataFrame_Regression() throws IOException {
String regressionIndex = "evaluate-regression-test-index";
createIndex(regressionIndex, mappingForRegression());
BulkRequest regressionBulk = new BulkRequest()
@ -1641,10 +1684,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex,
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
regressionIndex,
null,
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
evaluateDataFrameResponse =
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
@ -1671,12 +1718,16 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.endObject();
}
private static final String datasetField = "dataset";
private static final String actualField = "label";
private static final String probabilityField = "p";
private static XContentBuilder mappingForClassification() throws IOException {
return XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject(datasetField)
.field("type", "keyword")
.endObject()
.startObject(actualField)
.field("type", "keyword")
.endObject()
@ -1687,10 +1738,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.endObject();
}
private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) {
private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
.source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
}
private static final String actualRegression = "regression_actual";
@ -1725,7 +1776,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
BulkRequest bulk1 = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 10; ++i) {
bulk1.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
}
highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
@ -1751,7 +1802,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
BulkRequest bulk2 = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 10; i < 100; ++i) {
bulk2.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
}
highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);

View File

@ -178,7 +178,6 @@ import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.tasks.TaskId;
import org.hamcrest.CoreMatchers;
import org.junit.After;
import java.io.IOException;
@ -3178,16 +3177,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.1)) // #0
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.2)) // #1
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.3)) // #2
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.4)) // #3
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.7)) // #4
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.2)) // #5
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.3)) // #6
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.4)) // #7
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.8)) // #8
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.9)); // #9
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.1)) // #0
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.2)) // #1
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.3)) // #2
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.4)) // #3
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.7)) // #4
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.2)) // #5
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.3)) // #6
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.4)) // #7
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.8)) // #8
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.9)); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
@ -3195,14 +3194,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// tag::evaluate-data-frame-request
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1>
indexName, // <2>
new BinarySoftClassification( // <3>
"label", // <4>
"p", // <5>
// Evaluation metrics // <6>
PrecisionMetric.at(0.4, 0.5, 0.6), // <7>
RecallMetric.at(0.5, 0.7), // <8>
ConfusionMatrixMetric.at(0.5), // <9>
AucRocMetric.withCurve())); // <10>
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
new BinarySoftClassification( // <4>
"label", // <5>
"p", // <6>
// Evaluation metrics // <7>
PrecisionMetric.at(0.4, 0.5, 0.6), // <8>
RecallMetric.at(0.5, 0.7), // <9>
ConfusionMatrixMetric.at(0.5), // <10>
AucRocMetric.withCurve())); // <11>
// end::evaluate-data-frame-request
// tag::evaluate-data-frame-execute
@ -3223,14 +3223,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()),
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
assertThat(precision, closeTo(0.6, 1e-9));
assertThat(confusionMatrix.getTruePositives(), CoreMatchers.equalTo(2L)); // docs #8 and #9
assertThat(confusionMatrix.getFalsePositives(), CoreMatchers.equalTo(1L)); // doc #4
assertThat(confusionMatrix.getTrueNegatives(), CoreMatchers.equalTo(4L)); // docs #0, #1, #2 and #3
assertThat(confusionMatrix.getFalseNegatives(), CoreMatchers.equalTo(3L)); // docs #5, #6 and #7
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4
assertThat(confusionMatrix.getTrueNegatives(), equalTo(4L)); // docs #0, #1, #2 and #3
assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7
}
{
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(
indexName,
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")),
new BinarySoftClassification(
"label",
"p",

View File

@ -0,0 +1,82 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RegressionTests;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
public class EvaluateDataFrameRequestTests extends AbstractXContentTestCase<EvaluateDataFrameRequest> {

    /**
     * Builds a random request: between 1 and 5 random index names, an optional term query
     * and either a binary soft classification or a regression evaluation.
     */
    public static EvaluateDataFrameRequest createRandom() {
        final int indicesCount = randomIntBetween(1, 5);
        final List<String> indexNames = new ArrayList<>(indicesCount);
        while (indexNames.size() < indicesCount) {
            indexNames.add(randomAlphaOfLength(10));
        }

        // The query is optional on the request, so leave it out roughly half of the time.
        QueryConfig queryConfig = null;
        if (randomBoolean()) {
            final String field = randomAlphaOfLength(10);
            final String value = randomAlphaOfLength(10);
            queryConfig = new QueryConfig(QueryBuilders.termQuery(field, value));
        }

        final Evaluation evaluation;
        if (randomBoolean()) {
            evaluation = BinarySoftClassificationTests.createRandom();
        } else {
            evaluation = RegressionTests.createRandom();
        }
        return new EvaluateDataFrameRequest(indexNames, queryConfig, evaluation);
    }

    @Override
    protected EvaluateDataFrameRequest createTestInstance() {
        return createRandom();
    }

    @Override
    protected EvaluateDataFrameRequest doParseInstance(XContentParser parser) throws IOException {
        return EvaluateDataFrameRequest.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        // allow unknown fields in root only: every non-empty field path is excluded
        return field -> field.isEmpty() == false;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // Search module entries first, then the ML evaluation entries.
        final SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
        final List<NamedXContentRegistry.Entry> entries = new ArrayList<>(searchModule.getNamedXContents());
        entries.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(entries);
    }
}

View File

@ -36,8 +36,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected Regression createTestInstance() {
public static Regression createRandom() {
List<EvaluationMetric> metrics = new ArrayList<>();
if (randomBoolean()) {
metrics.add(new MeanSquaredErrorMetric());
@ -50,6 +49,11 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}
/** Supplies the instance under test by delegating to the static {@code createRandom()} factory. */
@Override
protected Regression createTestInstance() {
return createRandom();
}
@Override
protected Regression doParseInstance(XContentParser parser) throws IOException {
return Regression.fromXContent(parser);

View File

@ -37,8 +37,7 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase<Bina
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected BinarySoftClassification createTestInstance() {
public static BinarySoftClassification createRandom() {
List<EvaluationMetric> metrics = new ArrayList<>();
if (randomBoolean()) {
metrics.add(new AucRocMetric(randomBoolean()));
@ -66,6 +65,11 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase<Bina
new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}
/** Supplies the instance under test by delegating to the static {@code createRandom()} factory. */
@Override
protected BinarySoftClassification createTestInstance() {
return createRandom();
}
@Override
protected BinarySoftClassification doParseInstance(XContentParser parser) throws IOException {
return BinarySoftClassification.fromXContent(parser);

View File

@ -18,14 +18,15 @@ include-tagged::{doc-tests-file}[{api}-request]
--------------------------------------------------
<1> Constructing a new evaluation request
<2> Reference to an existing index
<3> Kind of evaluation to perform
<4> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
<5> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
<6> The remaining parameters are the metrics to be calculated based on the two fields described above.
<7> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
<8> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
<9> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
<10> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
<3> The query with which to select data from indices
<4> Kind of evaluation to perform
<5> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
<6> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
<7> The remaining parameters are the metrics to be calculated based on the two fields described above.
<8> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
<9> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
<10> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
<11> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
include::../execution.asciidoc[]

View File

@ -43,7 +43,13 @@ packages together commonly used metrics for various analyses.
`index`::
(Required, object) Defines the `index` in which the evaluation will be
performed.
`query`::
(Optional, object) Query used to select data from the index, written in
the {es} query domain-specific language (DSL). This value corresponds to the query
object in an {es} search POST body. By default, this property has the following
value: `{"match_all": {}}`.
`evaluation`::
(Required, object) Defines the type of evaluation you want to perform. For example:
`binary_soft_classification`. See <<ml-evaluate-dfanalytics-resources>>.

View File

@ -5,12 +5,14 @@
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@ -20,14 +22,21 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.Response> {
@ -41,14 +50,20 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
public static class Request extends ActionRequest implements ToXContentObject {
private static final ParseField INDEX = new ParseField("index");
private static final ParseField QUERY = new ParseField("query");
private static final ParseField EVALUATION = new ParseField("evaluation");
private static final ConstructingObjectParser<Request, Void> PARSER = new ConstructingObjectParser<>(NAME,
a -> new Request((List<String>) a[0], (Evaluation) a[1]));
private static final ConstructingObjectParser<Request, Void> PARSER = new ConstructingObjectParser<>(
NAME,
a -> new Request((List<String>) a[0], (QueryProvider) a[1], (Evaluation) a[2]));
static {
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDEX);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
PARSER.declareStringArray(constructorArg(), INDEX);
PARSER.declareObject(
optionalConstructorArg(),
(p, c) -> QueryProvider.fromXContent(p, true, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT),
QUERY);
PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
}
private static Evaluation parseEvaluation(XContentParser parser) throws IOException {
@ -64,19 +79,25 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
}
private String[] indices;
private QueryProvider queryProvider;
private Evaluation evaluation;
private Request(List<String> indices, Evaluation evaluation) {
private Request(List<String> indices, @Nullable QueryProvider queryProvider, Evaluation evaluation) {
setIndices(indices);
setQueryProvider(queryProvider);
setEvaluation(evaluation);
}
public Request() {
}
public Request() {}
public Request(StreamInput in) throws IOException {
super(in);
indices = in.readStringArray();
if (in.getVersion().onOrAfter(Version.V_7_4_0)) {
if (in.readBoolean()) {
queryProvider = QueryProvider.fromStream(in);
}
}
evaluation = in.readNamedWriteable(Evaluation.class);
}
@ -92,6 +113,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
this.indices = indices.toArray(new String[indices.size()]);
}
/**
 * Returns the parsed query of this request's query provider, falling back to
 * {@code QueryProvider.defaultQuery()} when no provider has been set.
 */
public QueryBuilder getParsedQuery() {
    final QueryProvider effectiveProvider = Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery);
    return effectiveProvider.getParsedQuery();
}
/**
 * Sets the query provider for this request. The value is stored as-is; {@code null} is accepted
 * (the private constructor passes its {@code @Nullable} argument through this setter).
 */
public final void setQueryProvider(QueryProvider queryProvider) {
this.queryProvider = queryProvider;
}
public Evaluation getEvaluation() {
return evaluation;
}
@ -109,6 +138,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
/**
 * Serializes this request to the given stream.
 * <p>
 * For streams at version 7.4.0 or later, a boolean presence flag is written after the
 * indices, followed by the query provider when one is set. Older streams receive neither.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    out.writeStringArray(indices);
    if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
        final boolean hasQueryProvider = queryProvider != null;
        // The flag tells the reader whether a query provider follows.
        out.writeBoolean(hasQueryProvider);
        if (hasQueryProvider) {
            queryProvider.writeTo(out);
        }
    }
    out.writeNamedWriteable(evaluation);
}
@ -116,16 +153,20 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.array(INDEX.getPreferredName(), indices);
builder.startObject(EVALUATION.getPreferredName());
builder.field(evaluation.getName(), evaluation);
builder.endObject();
// The query field is only emitted when a query provider was set.
if (queryProvider != null) {
builder.field(QUERY.getPreferredName(), queryProvider.getQuery());
}
// The evaluation is nested under its own name inside the evaluation object.
builder
.startObject(EVALUATION.getPreferredName())
.field(evaluation.getName(), evaluation)
.endObject();
builder.endObject();
return builder;
}
@Override
public int hashCode() {
return Objects.hash(Arrays.hashCode(indices), evaluation);
// queryProvider is part of the hash, matching its use in the equality comparison.
return Objects.hash(Arrays.hashCode(indices), queryProvider, evaluation);
}
@Override
@ -133,7 +174,9 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request that = (Request) o;
return Arrays.equals(indices, that.indices) && Objects.equals(evaluation, that.evaluation);
return Arrays.equals(indices, that.indices)
&& Objects.equals(queryProvider, that.queryProvider)
&& Objects.equals(evaluation, that.evaluation);
}
}
@ -200,5 +243,4 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
return Strings.toString(this);
}
}
}

View File

@ -143,7 +143,8 @@ public class DataFrameAnalyticsSource implements Writeable, ToXContentObject {
return deprecations;
}
public Map<String, Object> getQuery() {
// Visible for testing.
// Returns the query map exactly as the underlying query provider reports it.
Map<String, Object> getQuery() {
return queryProvider.getQuery();
}
}

View File

@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.util.List;
@ -25,8 +26,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
/**
* Builds the search required to collect data to compute the evaluation result
* @param queryBuilder User-provided query that must be respected when collecting data
* @return the search source to execute for this evaluation
*/
SearchSourceBuilder buildSearch();
SearchSourceBuilder buildSearch(QueryBuilder queryBuilder);
/**
* Computes the evaluation result

View File

@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
@ -106,10 +107,11 @@ public class Regression implements Evaluation {
}
@Override
public SearchSourceBuilder buildSearch() {
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery(actualField))
.filter(QueryBuilders.existsQuery(predictedField));
.filter(QueryBuilders.existsQuery(predictedField))
.filter(queryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (RegressionMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);

View File

@ -155,10 +155,12 @@ public class BinarySoftClassification implements Evaluation {
}
@Override
public SearchSourceBuilder buildSearch() {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.size(0);
searchSourceBuilder.query(buildQuery());
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery(actualField))
.filter(QueryBuilders.existsQuery(predictedProbabilityField))
.filter(queryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (SoftClassificationMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo()));
aggs.forEach(searchSourceBuilder::aggregation);
@ -166,13 +168,6 @@ public class BinarySoftClassification implements Evaluation {
return searchSourceBuilder;
}
/**
 * Builds a bool query that only matches documents in which both the actual field and
 * the predicted probability field exist.
 *
 * @return the bool query holding the two exists filters
 */
private QueryBuilder buildQuery() {
    return QueryBuilders.boolQuery()
        .filter(QueryBuilders.existsQuery(actualField))
        .filter(QueryBuilders.existsQuery(predictedProbabilityField));
}
@Override
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
if (searchResponse.getHits().getTotalHits().value == 0) {

View File

@ -7,26 +7,41 @@ package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Request;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTestCase<Request> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
// Combines the ML evaluation writeables with those registered by SearchModule.
// NOTE(review): presumably the SearchModule entries are needed to (de)serialize the query
// the request can now carry — confirm.
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables());
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
return new NamedWriteableRegistry(namedWriteables);
}
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
// Combines the ML evaluation parsers with the named x-content entries registered by SearchModule.
// NOTE(review): presumably the SearchModule entries are needed to parse the query
// the request can now carry — confirm.
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
@Override
@ -38,7 +53,18 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
indices.add(randomAlphaOfLength(10));
}
request.setIndices(indices);
request.setEvaluation(BinarySoftClassificationTests.createRandom());
QueryProvider queryProvider = null;
if (randomBoolean()) {
try {
queryProvider = QueryProvider.fromParsedQuery(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)));
} catch (IOException e) {
// Should never happen
throw new UncheckedIOException(e);
}
}
request.setQueryProvider(queryProvider);
Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
request.setEvaluation(evaluation);
return request;
}

View File

@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -69,4 +72,20 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
() -> new Regression("foo", "bar", Collections.emptyList()));
assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
}
/**
 * Verifies that the search built by a regression evaluation filters on the existence of
 * both evaluated fields and appends the user-provided query as an additional filter clause.
 */
public void testBuildSearch() {
    Regression regression = new Regression("act", "prob", Arrays.asList(new MeanSquaredError()));

    // Query as a caller would supply it.
    QueryBuilder userQuery = QueryBuilders.boolQuery()
        .filter(QueryBuilders.termQuery("field_A", "some-value"))
        .filter(QueryBuilders.termQuery("field_B", "some-other-value"));

    // The exists filters come first, followed by the user query as one more filter.
    QueryBuilder expectedQuery = QueryBuilders.boolQuery()
        .filter(QueryBuilders.existsQuery("act"))
        .filter(QueryBuilders.existsQuery("prob"))
        .filter(QueryBuilders.boolQuery()
            .filter(QueryBuilders.termQuery("field_A", "some-value"))
            .filter(QueryBuilders.termQuery("field_B", "some-other-value")));

    QueryBuilder actualQuery = regression.buildSearch(userQuery).query();
    assertThat(actualQuery, equalTo(expectedQuery));
}
}

View File

@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -76,4 +79,20 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
() -> new BinarySoftClassification("foo", "bar", Collections.emptyList()));
assertThat(e.getMessage(), equalTo("[binary_soft_classification] must have one or more metrics"));
}
/**
 * Verifies that the search built by a binary soft classification evaluation filters on the
 * existence of both evaluated fields and appends the user-provided query as an additional
 * filter clause.
 */
public void testBuildSearch() {
    BinarySoftClassification classification =
        new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));

    // Query as a caller would supply it.
    QueryBuilder userQuery = QueryBuilders.boolQuery()
        .filter(QueryBuilders.termQuery("field_A", "some-value"))
        .filter(QueryBuilders.termQuery("field_B", "some-other-value"));

    // The exists filters come first, followed by the user query as one more filter.
    QueryBuilder expectedQuery = QueryBuilders.boolQuery()
        .filter(QueryBuilders.existsQuery("act"))
        .filter(QueryBuilders.existsQuery("prob"))
        .filter(QueryBuilders.boolQuery()
            .filter(QueryBuilders.termQuery("field_A", "some-value"))
            .filter(QueryBuilders.termQuery("field_B", "some-other-value")));

    QueryBuilder actualQuery = classification.buildSearch(userQuery).query();
    assertThat(actualQuery, equalTo(expectedQuery));
}
}

View File

@ -40,7 +40,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
ActionListener<EvaluateDataFrameAction.Response> listener) {
Evaluation evaluation = request.getEvaluation();
SearchRequest searchRequest = new SearchRequest(request.getIndices());
searchRequest.source(evaluation.buildSearch());
searchRequest.source(evaluation.buildSearch(request.getParsedQuery()));
ActionListener<List<EvaluationMetricResult>> resultsListener = ActionListener.wrap(
results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)),

View File

@ -5,6 +5,7 @@ setup:
index: utopia
body: >
{
"dataset": "blue",
"is_outlier": false,
"is_outlier_int": 0,
"outlier_score": 0.0,
@ -19,6 +20,7 @@ setup:
index: utopia
body: >
{
"dataset": "blue",
"is_outlier": false,
"is_outlier_int": 0,
"outlier_score": 0.2,
@ -33,6 +35,7 @@ setup:
index: utopia
body: >
{
"dataset": "blue",
"is_outlier": false,
"is_outlier_int": 0,
"outlier_score": 0.3,
@ -47,6 +50,7 @@ setup:
index: utopia
body: >
{
"dataset": "blue",
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.3,
@ -61,6 +65,7 @@ setup:
index: utopia
body: >
{
"dataset": "green",
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.4,
@ -75,6 +80,7 @@ setup:
index: utopia
body: >
{
"dataset": "green",
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.5,
@ -89,6 +95,7 @@ setup:
index: utopia
body: >
{
"dataset": "green",
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.9,
@ -103,6 +110,7 @@ setup:
index: utopia
body: >
{
"dataset": "green",
"is_outlier": true,
"is_outlier_int": 1,
"outlier_score": 0.95,
@ -305,6 +313,33 @@ setup:
tn: 3
fn: 2
---
"Test binary_soft_classification with query":
# The query keeps only the "blue" documents indexed in setup: three with is_outlier=false
# (scores 0.0, 0.2, 0.3) and one with is_outlier=true (score 0.3). None of these scores
# reaches the 0.5 threshold, hence tp=0, fp=0, tn=3, fn=1.
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"query": { "bool": { "filter": { "term": { "dataset": "blue" } } } },
"evaluation": {
"binary_soft_classification": {
"actual_field": "is_outlier",
"predicted_probability_field": "outlier_score",
"metrics": {
"confusion_matrix": { "at": [0.5] }
}
}
}
}
- match:
binary_soft_classification:
confusion_matrix:
'0.5':
tp: 0
fp: 0
tn: 3
fn: 1
---
"Test binary_soft_classification default metrics":
- do: