This commit is contained in:
parent
b95ca9c3bb
commit
7512337922
|
@ -21,7 +21,9 @@ package org.elasticsearch.client.ml;
|
|||
|
||||
import org.elasticsearch.client.Validatable;
|
||||
import org.elasticsearch.client.ValidationException;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
@ -37,20 +39,25 @@ import java.util.Objects;
|
|||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
|
||||
|
||||
public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
|
||||
|
||||
private static final ParseField INDEX = new ParseField("index");
|
||||
private static final ParseField QUERY = new ParseField("query");
|
||||
private static final ParseField EVALUATION = new ParseField("evaluation");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<EvaluateDataFrameRequest, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List<String>) args[0], (Evaluation) args[1]));
|
||||
"evaluate_data_frame_request",
|
||||
true,
|
||||
args -> new EvaluateDataFrameRequest((List<String>) args[0], (QueryConfig) args[1], (Evaluation) args[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareStringArray(constructorArg(), INDEX);
|
||||
PARSER.declareObject(optionalConstructorArg(), (p, c) -> QueryConfig.fromXContent(p), QUERY);
|
||||
PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
|
||||
}
|
||||
|
||||
|
@ -67,14 +74,16 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
|
|||
}
|
||||
|
||||
private List<String> indices;
|
||||
private QueryConfig queryConfig;
|
||||
private Evaluation evaluation;
|
||||
|
||||
public EvaluateDataFrameRequest(String index, Evaluation evaluation) {
|
||||
this(Arrays.asList(index), evaluation);
|
||||
public EvaluateDataFrameRequest(String index, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
|
||||
this(Arrays.asList(index), queryConfig, evaluation);
|
||||
}
|
||||
|
||||
public EvaluateDataFrameRequest(List<String> indices, Evaluation evaluation) {
|
||||
public EvaluateDataFrameRequest(List<String> indices, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
|
||||
setIndices(indices);
|
||||
setQueryConfig(queryConfig);
|
||||
setEvaluation(evaluation);
|
||||
}
|
||||
|
||||
|
@ -87,6 +96,14 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
|
|||
this.indices = new ArrayList<>(indices);
|
||||
}
|
||||
|
||||
public QueryConfig getQueryConfig() {
|
||||
return queryConfig;
|
||||
}
|
||||
|
||||
public final void setQueryConfig(QueryConfig queryConfig) {
|
||||
this.queryConfig = queryConfig;
|
||||
}
|
||||
|
||||
public Evaluation getEvaluation() {
|
||||
return evaluation;
|
||||
}
|
||||
|
@ -111,18 +128,22 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
|
|||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return builder
|
||||
.startObject()
|
||||
.array(INDEX.getPreferredName(), indices.toArray())
|
||||
.startObject(EVALUATION.getPreferredName())
|
||||
.field(evaluation.getName(), evaluation)
|
||||
.endObject()
|
||||
builder.startObject();
|
||||
builder.array(INDEX.getPreferredName(), indices.toArray());
|
||||
if (queryConfig != null) {
|
||||
builder.field(QUERY.getPreferredName(), queryConfig.getQuery());
|
||||
}
|
||||
builder
|
||||
.startObject(EVALUATION.getPreferredName())
|
||||
.field(evaluation.getName(), evaluation)
|
||||
.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(indices, evaluation);
|
||||
return Objects.hash(indices, queryConfig, evaluation);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -131,6 +152,7 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o;
|
||||
return Objects.equals(indices, that.indices)
|
||||
&& Objects.equals(queryConfig, that.queryConfig)
|
||||
&& Objects.equals(evaluation, that.evaluation);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
|
|||
import org.elasticsearch.client.ml.DeleteJobRequest;
|
||||
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
|
||||
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
|
||||
import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests;
|
||||
import org.elasticsearch.client.ml.FindFileStructureRequest;
|
||||
import org.elasticsearch.client.ml.FindFileStructureRequestTests;
|
||||
import org.elasticsearch.client.ml.FlushJobRequest;
|
||||
|
@ -85,9 +86,6 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests;
|
|||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
||||
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.client.ml.job.config.Detector;
|
||||
|
@ -779,13 +777,7 @@ public class MLRequestConvertersTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testEvaluateDataFrame() throws IOException {
|
||||
EvaluateDataFrameRequest evaluateRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
Arrays.asList(generateRandomStringArray(1, 10, false, false)),
|
||||
new BinarySoftClassification(
|
||||
randomAlphaOfLengthBetween(1, 10),
|
||||
randomAlphaOfLengthBetween(1, 10),
|
||||
PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7)));
|
||||
EvaluateDataFrameRequest evaluateRequest = EvaluateDataFrameRequestTests.createRandom();
|
||||
Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest);
|
||||
assertEquals(HttpPost.METHOD_NAME, request.getMethod());
|
||||
assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint());
|
||||
|
|
|
@ -149,6 +149,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.junit.After;
|
||||
|
@ -1455,7 +1456,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
public void testStopDataFrameAnalyticsConfig() throws Exception {
|
||||
String sourceIndex = "stop-test-source-index";
|
||||
String destIndex = "stop-test-dest-index";
|
||||
createIndex(sourceIndex, mappingForClassification());
|
||||
createIndex(sourceIndex, defaultMappingForTest());
|
||||
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
|
||||
|
||||
|
@ -1553,27 +1554,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(exception.status().getStatus(), equalTo(404));
|
||||
}
|
||||
|
||||
public void testEvaluateDataFrame() throws IOException {
|
||||
public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
|
||||
String indexName = "evaluate-test-index";
|
||||
createIndex(indexName, mappingForClassification());
|
||||
BulkRequest bulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForClassification(indexName, false, 0.1)) // #0
|
||||
.add(docForClassification(indexName, false, 0.2)) // #1
|
||||
.add(docForClassification(indexName, false, 0.3)) // #2
|
||||
.add(docForClassification(indexName, false, 0.4)) // #3
|
||||
.add(docForClassification(indexName, false, 0.7)) // #4
|
||||
.add(docForClassification(indexName, true, 0.2)) // #5
|
||||
.add(docForClassification(indexName, true, 0.3)) // #6
|
||||
.add(docForClassification(indexName, true, 0.4)) // #7
|
||||
.add(docForClassification(indexName, true, 0.8)) // #8
|
||||
.add(docForClassification(indexName, true, 0.9)); // #9
|
||||
.add(docForClassification(indexName, "blue", false, 0.1)) // #0
|
||||
.add(docForClassification(indexName, "blue", false, 0.2)) // #1
|
||||
.add(docForClassification(indexName, "blue", false, 0.3)) // #2
|
||||
.add(docForClassification(indexName, "blue", false, 0.4)) // #3
|
||||
.add(docForClassification(indexName, "blue", false, 0.7)) // #4
|
||||
.add(docForClassification(indexName, "blue", true, 0.2)) // #5
|
||||
.add(docForClassification(indexName, "green", true, 0.3)) // #6
|
||||
.add(docForClassification(indexName, "green", true, 0.4)) // #7
|
||||
.add(docForClassification(indexName, "green", true, 0.8)) // #8
|
||||
.add(docForClassification(indexName, "green", true, 0.9)); // #9
|
||||
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new BinarySoftClassification(
|
||||
actualField,
|
||||
probabilityField,
|
||||
|
@ -1624,7 +1626,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
|
||||
assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
|
||||
assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
|
||||
}
|
||||
|
||||
public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
|
||||
String indexName = "evaluate-with-query-test-index";
|
||||
createIndex(indexName, mappingForClassification());
|
||||
BulkRequest bulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #0
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #1
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #2
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #3
|
||||
.add(docForClassification(indexName, "blue", true, 0.0)) // #4
|
||||
.add(docForClassification(indexName, "blue", true, 0.0)) // #5
|
||||
.add(docForClassification(indexName, "green", true, 0.0)) // #6
|
||||
.add(docForClassification(indexName, "green", true, 0.0)) // #7
|
||||
.add(docForClassification(indexName, "green", true, 0.0)) // #8
|
||||
.add(docForClassification(indexName, "green", true, 1.0)); // #9
|
||||
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
// Request only "blue" subset to be evaluated
|
||||
new QueryConfig(QueryBuilders.termQuery(datasetField, "blue")),
|
||||
new BinarySoftClassification(actualField, probabilityField, ConfusionMatrixMetric.at(0.5)));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME));
|
||||
ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5");
|
||||
assertThat(confusionMatrix.getTruePositives(), equalTo(4L)); // docs #0, #1, #2 and #3
|
||||
assertThat(confusionMatrix.getFalsePositives(), equalTo(0L));
|
||||
assertThat(confusionMatrix.getTrueNegatives(), equalTo(0L));
|
||||
assertThat(confusionMatrix.getFalseNegatives(), equalTo(2L)); // docs #4 and #5
|
||||
}
|
||||
|
||||
public void testEvaluateDataFrame_Regression() throws IOException {
|
||||
String regressionIndex = "evaluate-regression-test-index";
|
||||
createIndex(regressionIndex, mappingForRegression());
|
||||
BulkRequest regressionBulk = new BulkRequest()
|
||||
|
@ -1641,10 +1684,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
|
||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||
|
||||
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex,
|
||||
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
regressionIndex,
|
||||
null,
|
||||
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
||||
|
||||
evaluateDataFrameResponse =
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
|
||||
|
@ -1671,12 +1718,16 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.endObject();
|
||||
}
|
||||
|
||||
private static final String datasetField = "dataset";
|
||||
private static final String actualField = "label";
|
||||
private static final String probabilityField = "p";
|
||||
|
||||
private static XContentBuilder mappingForClassification() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
.startObject(datasetField)
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.startObject(actualField)
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
|
@ -1687,10 +1738,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) {
|
||||
private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
|
||||
return new IndexRequest()
|
||||
.index(indexName)
|
||||
.source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
|
||||
.source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
|
||||
}
|
||||
|
||||
private static final String actualRegression = "regression_actual";
|
||||
|
@ -1725,7 +1776,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
BulkRequest bulk1 = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
bulk1.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
}
|
||||
highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
|
||||
|
||||
|
@ -1751,7 +1802,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
BulkRequest bulk2 = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 10; i < 100; ++i) {
|
||||
bulk2.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
}
|
||||
highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
|
||||
|
||||
|
|
|
@ -178,7 +178,6 @@ import org.elasticsearch.search.aggregations.AggregatorFactories;
|
|||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
|
||||
import org.elasticsearch.tasks.TaskId;
|
||||
import org.hamcrest.CoreMatchers;
|
||||
import org.junit.After;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -3178,16 +3177,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
BulkRequest bulkRequest =
|
||||
new BulkRequest(indexName)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.1)) // #0
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.2)) // #1
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.3)) // #2
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.4)) // #3
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.7)) // #4
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.2)) // #5
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.3)) // #6
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.4)) // #7
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.8)) // #8
|
||||
.add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.9)); // #9
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.1)) // #0
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.2)) // #1
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.3)) // #2
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.4)) // #3
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.7)) // #4
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.2)) // #5
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.3)) // #6
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.4)) // #7
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.8)) // #8
|
||||
.add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true, "p", 0.9)); // #9
|
||||
RestHighLevelClient client = highLevelClient();
|
||||
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
||||
|
@ -3195,14 +3194,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
// tag::evaluate-data-frame-request
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1>
|
||||
indexName, // <2>
|
||||
new BinarySoftClassification( // <3>
|
||||
"label", // <4>
|
||||
"p", // <5>
|
||||
// Evaluation metrics // <6>
|
||||
PrecisionMetric.at(0.4, 0.5, 0.6), // <7>
|
||||
RecallMetric.at(0.5, 0.7), // <8>
|
||||
ConfusionMatrixMetric.at(0.5), // <9>
|
||||
AucRocMetric.withCurve())); // <10>
|
||||
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
|
||||
new BinarySoftClassification( // <4>
|
||||
"label", // <5>
|
||||
"p", // <6>
|
||||
// Evaluation metrics // <7>
|
||||
PrecisionMetric.at(0.4, 0.5, 0.6), // <8>
|
||||
RecallMetric.at(0.5, 0.7), // <9>
|
||||
ConfusionMatrixMetric.at(0.5), // <10>
|
||||
AucRocMetric.withCurve())); // <11>
|
||||
// end::evaluate-data-frame-request
|
||||
|
||||
// tag::evaluate-data-frame-execute
|
||||
|
@ -3223,14 +3223,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()),
|
||||
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
|
||||
assertThat(precision, closeTo(0.6, 1e-9));
|
||||
assertThat(confusionMatrix.getTruePositives(), CoreMatchers.equalTo(2L)); // docs #8 and #9
|
||||
assertThat(confusionMatrix.getFalsePositives(), CoreMatchers.equalTo(1L)); // doc #4
|
||||
assertThat(confusionMatrix.getTrueNegatives(), CoreMatchers.equalTo(4L)); // docs #0, #1, #2 and #3
|
||||
assertThat(confusionMatrix.getFalseNegatives(), CoreMatchers.equalTo(3L)); // docs #5, #6 and #7
|
||||
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
|
||||
assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4
|
||||
assertThat(confusionMatrix.getTrueNegatives(), equalTo(4L)); // docs #0, #1, #2 and #3
|
||||
assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7
|
||||
}
|
||||
{
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")),
|
||||
new BinarySoftClassification(
|
||||
"label",
|
||||
"p",
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RegressionTests;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class EvaluateDataFrameRequestTests extends AbstractXContentTestCase<EvaluateDataFrameRequest> {
|
||||
|
||||
public static EvaluateDataFrameRequest createRandom() {
|
||||
int indicesCount = randomIntBetween(1, 5);
|
||||
List<String> indices = new ArrayList<>(indicesCount);
|
||||
for (int i = 0; i < indicesCount; i++) {
|
||||
indices.add(randomAlphaOfLength(10));
|
||||
}
|
||||
QueryConfig queryConfig = randomBoolean()
|
||||
? new QueryConfig(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)))
|
||||
: null;
|
||||
Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
|
||||
return new EvaluateDataFrameRequest(indices, queryConfig, evaluation);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EvaluateDataFrameRequest createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected EvaluateDataFrameRequest doParseInstance(XContentParser parser) throws IOException {
|
||||
return EvaluateDataFrameRequest.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
// allow unknown fields in root only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
}
|
|
@ -36,8 +36,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression createTestInstance() {
|
||||
public static Regression createRandom() {
|
||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new MeanSquaredErrorMetric());
|
||||
|
@ -50,6 +49,11 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||
new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
||||
return Regression.fromXContent(parser);
|
||||
|
|
|
@ -37,8 +37,7 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase<Bina
|
|||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BinarySoftClassification createTestInstance() {
|
||||
public static BinarySoftClassification createRandom() {
|
||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new AucRocMetric(randomBoolean()));
|
||||
|
@ -66,6 +65,11 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase<Bina
|
|||
new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BinarySoftClassification createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BinarySoftClassification doParseInstance(XContentParser parser) throws IOException {
|
||||
return BinarySoftClassification.fromXContent(parser);
|
||||
|
|
|
@ -18,14 +18,15 @@ include-tagged::{doc-tests-file}[{api}-request]
|
|||
--------------------------------------------------
|
||||
<1> Constructing a new evaluation request
|
||||
<2> Reference to an existing index
|
||||
<3> Kind of evaluation to perform
|
||||
<4> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
|
||||
<5> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
|
||||
<6> The remaining parameters are the metrics to be calculated based on the two fields described above.
|
||||
<7> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
|
||||
<8> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
|
||||
<9> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
|
||||
<10> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
|
||||
<3> The query with which to select data from indices
|
||||
<4> Kind of evaluation to perform
|
||||
<5> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
|
||||
<6> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
|
||||
<7> The remaining parameters are the metrics to be calculated based on the two fields described above.
|
||||
<8> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
|
||||
<9> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
|
||||
<10> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
|
||||
<11> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
|
||||
|
||||
include::../execution.asciidoc[]
|
||||
|
||||
|
|
|
@ -43,7 +43,13 @@ packages together commonly used metrics for various analyses.
|
|||
`index`::
|
||||
(Required, object) Defines the `index` in which the evaluation will be
|
||||
performed.
|
||||
|
||||
|
||||
`query`::
|
||||
(Optional, object) Query used to select data from the index.
|
||||
The {es} query domain-specific language (DSL). This value corresponds to the query
|
||||
object in an {es} search POST body. By default, this property has the following
|
||||
value: `{"match_all": {}}`.
|
||||
|
||||
`evaluation`::
|
||||
(Required, object) Defines the type of evaluation you want to perform. For example:
|
||||
`binary_soft_classification`. See <<ml-evaluate-dfanalytics-resources>>.
|
||||
|
|
|
@ -5,12 +5,14 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionRequest;
|
||||
import org.elasticsearch.action.ActionRequestBuilder;
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
import org.elasticsearch.client.ElasticsearchClient;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -20,14 +22,21 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParserUtils;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.Response> {
|
||||
|
||||
|
@ -41,14 +50,20 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
public static class Request extends ActionRequest implements ToXContentObject {
|
||||
|
||||
private static final ParseField INDEX = new ParseField("index");
|
||||
private static final ParseField QUERY = new ParseField("query");
|
||||
private static final ParseField EVALUATION = new ParseField("evaluation");
|
||||
|
||||
private static final ConstructingObjectParser<Request, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
a -> new Request((List<String>) a[0], (Evaluation) a[1]));
|
||||
private static final ConstructingObjectParser<Request, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
a -> new Request((List<String>) a[0], (QueryProvider) a[1], (Evaluation) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDEX);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
|
||||
PARSER.declareStringArray(constructorArg(), INDEX);
|
||||
PARSER.declareObject(
|
||||
optionalConstructorArg(),
|
||||
(p, c) -> QueryProvider.fromXContent(p, true, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT),
|
||||
QUERY);
|
||||
PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
|
||||
}
|
||||
|
||||
private static Evaluation parseEvaluation(XContentParser parser) throws IOException {
|
||||
|
@ -64,19 +79,25 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
}
|
||||
|
||||
private String[] indices;
|
||||
private QueryProvider queryProvider;
|
||||
private Evaluation evaluation;
|
||||
|
||||
private Request(List<String> indices, Evaluation evaluation) {
|
||||
private Request(List<String> indices, @Nullable QueryProvider queryProvider, Evaluation evaluation) {
|
||||
setIndices(indices);
|
||||
setQueryProvider(queryProvider);
|
||||
setEvaluation(evaluation);
|
||||
}
|
||||
|
||||
public Request() {
|
||||
}
|
||||
public Request() {}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
indices = in.readStringArray();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_4_0)) {
|
||||
if (in.readBoolean()) {
|
||||
queryProvider = QueryProvider.fromStream(in);
|
||||
}
|
||||
}
|
||||
evaluation = in.readNamedWriteable(Evaluation.class);
|
||||
}
|
||||
|
||||
|
@ -92,6 +113,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
this.indices = indices.toArray(new String[indices.size()]);
|
||||
}
|
||||
|
||||
public QueryBuilder getParsedQuery() {
|
||||
return Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery).getParsedQuery();
|
||||
}
|
||||
|
||||
public final void setQueryProvider(QueryProvider queryProvider) {
|
||||
this.queryProvider = queryProvider;
|
||||
}
|
||||
|
||||
public Evaluation getEvaluation() {
|
||||
return evaluation;
|
||||
}
|
||||
|
@ -109,6 +138,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
out.writeStringArray(indices);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
|
||||
if (queryProvider != null) {
|
||||
out.writeBoolean(true);
|
||||
queryProvider.writeTo(out);
|
||||
} else {
|
||||
out.writeBoolean(false);
|
||||
}
|
||||
}
|
||||
out.writeNamedWriteable(evaluation);
|
||||
}
|
||||
|
||||
|
@ -116,16 +153,20 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.array(INDEX.getPreferredName(), indices);
|
||||
builder.startObject(EVALUATION.getPreferredName());
|
||||
builder.field(evaluation.getName(), evaluation);
|
||||
builder.endObject();
|
||||
if (queryProvider != null) {
|
||||
builder.field(QUERY.getPreferredName(), queryProvider.getQuery());
|
||||
}
|
||||
builder
|
||||
.startObject(EVALUATION.getPreferredName())
|
||||
.field(evaluation.getName(), evaluation)
|
||||
.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(Arrays.hashCode(indices), evaluation);
|
||||
return Objects.hash(Arrays.hashCode(indices), queryProvider, evaluation);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -133,7 +174,9 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Request that = (Request) o;
|
||||
return Arrays.equals(indices, that.indices) && Objects.equals(evaluation, that.evaluation);
|
||||
return Arrays.equals(indices, that.indices)
|
||||
&& Objects.equals(queryProvider, that.queryProvider)
|
||||
&& Objects.equals(evaluation, that.evaluation);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -200,5 +243,4 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -143,7 +143,8 @@ public class DataFrameAnalyticsSource implements Writeable, ToXContentObject {
|
|||
return deprecations;
|
||||
}
|
||||
|
||||
public Map<String, Object> getQuery() {
|
||||
// Visible for testing
|
||||
Map<String, Object> getQuery() {
|
||||
return queryProvider.getQuery();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener;
|
|||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -25,8 +26,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
|
||||
/**
|
||||
* Builds the search required to collect data to compute the evaluation result
|
||||
* @param queryBuilder User-provided query that must be respected when collecting data
|
||||
*/
|
||||
SearchSourceBuilder buildSearch();
|
||||
SearchSourceBuilder buildSearch(QueryBuilder queryBuilder);
|
||||
|
||||
/**
|
||||
* Computes the evaluation result
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
|
@ -106,10 +107,11 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch() {
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery(actualField))
|
||||
.filter(QueryBuilders.existsQuery(predictedField));
|
||||
.filter(QueryBuilders.existsQuery(predictedField))
|
||||
.filter(queryBuilder);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
for (RegressionMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
|
||||
|
|
|
@ -155,10 +155,12 @@ public class BinarySoftClassification implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch() {
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
|
||||
searchSourceBuilder.size(0);
|
||||
searchSourceBuilder.query(buildQuery());
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery(actualField))
|
||||
.filter(QueryBuilders.existsQuery(predictedProbabilityField))
|
||||
.filter(queryBuilder);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
for (SoftClassificationMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo()));
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
|
@ -166,13 +168,6 @@ public class BinarySoftClassification implements Evaluation {
|
|||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
private QueryBuilder buildQuery() {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
|
||||
boolQuery.filter(QueryBuilders.existsQuery(actualField));
|
||||
boolQuery.filter(QueryBuilders.existsQuery(predictedProbabilityField));
|
||||
return boolQuery;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
|
|
|
@ -7,26 +7,41 @@ package org.elasticsearch.xpack.core.ml.action;
|
|||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests;
|
||||
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTestCase<Request> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
|
||||
return new NamedWriteableRegistry(namedWriteables);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -38,7 +53,18 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
|
|||
indices.add(randomAlphaOfLength(10));
|
||||
}
|
||||
request.setIndices(indices);
|
||||
request.setEvaluation(BinarySoftClassificationTests.createRandom());
|
||||
QueryProvider queryProvider = null;
|
||||
if (randomBoolean()) {
|
||||
try {
|
||||
queryProvider = QueryProvider.fromParsedQuery(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
||||
} catch (IOException e) {
|
||||
// Should never happen
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
}
|
||||
request.setQueryProvider(queryProvider);
|
||||
Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
|
||||
request.setEvaluation(evaluation);
|
||||
return request;
|
||||
}
|
||||
|
||||
|
|
|
@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -69,4 +72,20 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
() -> new Regression("foo", "bar", Collections.emptyList()));
|
||||
assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
|
||||
}
|
||||
|
||||
public void testBuildSearch() {
|
||||
Regression evaluation = new Regression("act", "prob", Arrays.asList(new MeanSquaredError()));
|
||||
QueryBuilder userProvidedQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value"));
|
||||
QueryBuilder expectedSearchQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery("act"))
|
||||
.filter(QueryBuilders.existsQuery("prob"))
|
||||
.filter(QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
|
||||
assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -76,4 +79,20 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
|
|||
() -> new BinarySoftClassification("foo", "bar", Collections.emptyList()));
|
||||
assertThat(e.getMessage(), equalTo("[binary_soft_classification] must have one or more metrics"));
|
||||
}
|
||||
|
||||
public void testBuildSearch() {
|
||||
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
|
||||
QueryBuilder userProvidedQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value"));
|
||||
QueryBuilder expectedSearchQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery("act"))
|
||||
.filter(QueryBuilders.existsQuery("prob"))
|
||||
.filter(QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
|
||||
assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
ActionListener<EvaluateDataFrameAction.Response> listener) {
|
||||
Evaluation evaluation = request.getEvaluation();
|
||||
SearchRequest searchRequest = new SearchRequest(request.getIndices());
|
||||
searchRequest.source(evaluation.buildSearch());
|
||||
searchRequest.source(evaluation.buildSearch(request.getParsedQuery()));
|
||||
|
||||
ActionListener<List<EvaluationMetricResult>> resultsListener = ActionListener.wrap(
|
||||
results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)),
|
||||
|
|
|
@ -5,6 +5,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "blue",
|
||||
"is_outlier": false,
|
||||
"is_outlier_int": 0,
|
||||
"outlier_score": 0.0,
|
||||
|
@ -19,6 +20,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "blue",
|
||||
"is_outlier": false,
|
||||
"is_outlier_int": 0,
|
||||
"outlier_score": 0.2,
|
||||
|
@ -33,6 +35,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "blue",
|
||||
"is_outlier": false,
|
||||
"is_outlier_int": 0,
|
||||
"outlier_score": 0.3,
|
||||
|
@ -47,6 +50,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "blue",
|
||||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.3,
|
||||
|
@ -61,6 +65,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "green",
|
||||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.4,
|
||||
|
@ -75,6 +80,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "green",
|
||||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.5,
|
||||
|
@ -89,6 +95,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "green",
|
||||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.9,
|
||||
|
@ -103,6 +110,7 @@ setup:
|
|||
index: utopia
|
||||
body: >
|
||||
{
|
||||
"dataset": "green",
|
||||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.95,
|
||||
|
@ -305,6 +313,33 @@ setup:
|
|||
tn: 3
|
||||
fn: 2
|
||||
|
||||
---
|
||||
"Test binary_soft_classification with query":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"query": { "bool": { "filter": { "term": { "dataset": "blue" } } } },
|
||||
"evaluation": {
|
||||
"binary_soft_classification": {
|
||||
"actual_field": "is_outlier",
|
||||
"predicted_probability_field": "outlier_score",
|
||||
"metrics": {
|
||||
"confusion_matrix": { "at": [0.5] }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
- match:
|
||||
binary_soft_classification:
|
||||
confusion_matrix:
|
||||
'0.5':
|
||||
tp: 0
|
||||
fp: 0
|
||||
tn: 3
|
||||
fn: 1
|
||||
|
||||
---
|
||||
"Test binary_soft_classification default metrics":
|
||||
- do:
|
||||
|
|
Loading…
Reference in New Issue