[7.x] Add MlClientDocumentationIT tests for classification. (#47569) (#47896)

This commit is contained in:
Przemysław Witek 2019-10-11 10:19:55 +02:00 committed by GitHub
parent e60221d2bd
commit d210bfa888
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 286 additions and 40 deletions

View File

@ -1776,7 +1776,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
new EvaluateDataFrameRequest(
regressionIndex,
null,
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@ -1933,7 +1933,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
}
private static final String actualRegression = "regression_actual";
private static final String probabilityRegression = "regression_prob";
private static final String predictedRegression = "regression_predicted";
private static XContentBuilder mappingForRegression() throws IOException {
return XContentFactory.jsonBuilder().startObject()
@ -1941,17 +1941,17 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.startObject(actualRegression)
.field("type", "double")
.endObject()
.startObject(probabilityRegression)
.startObject(predictedRegression)
.field("type", "double")
.endObject()
.endObject()
.endObject();
}
private static IndexRequest docForRegression(String indexName, double act, double p) {
private static IndexRequest docForRegression(String indexName, double actualValue, double predictedValue) {
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
.source(XContentType.JSON, actualRegression, actualValue, predictedRegression, predictedValue);
}
private void createIndex(String indexName, XContentBuilder mapping) throws IOException {

View File

@ -139,8 +139,11 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
@ -2821,7 +2824,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
List<DataFrameAnalyticsConfig> configs = response.getAnalytics();
// end::get-data-frame-analytics-response
assertThat(configs.size(), equalTo(1));
assertThat(configs, hasSize(1));
}
{
GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config");
@ -2871,7 +2874,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
List<DataFrameAnalyticsStats> stats = response.getAnalyticsStats();
// end::get-data-frame-analytics-stats-response
assertThat(stats.size(), equalTo(1));
assertThat(stats, hasSize(1));
}
{
GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config");
@ -2939,8 +2942,20 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.build();
// end::put-data-frame-analytics-outlier-detection-customized
// tag::put-data-frame-analytics-classification
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
.setMaximumNumberTrees(50) // <5>
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.build();
// end::put-data-frame-analytics-classification
// tag::put-data-frame-analytics-regression
DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
DataFrameAnalysis regression = org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
@ -3209,18 +3224,24 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
{
// tag::evaluate-data-frame-evaluation-softclassification
Evaluation evaluation =
new BinarySoftClassification( // <1>
"label", // <2>
"p", // <3>
// Evaluation metrics // <4>
PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
RecallMetric.at(0.5, 0.7), // <6>
ConfusionMatrixMetric.at(0.5), // <7>
AucRocMetric.withCurve()); // <8>
// end::evaluate-data-frame-evaluation-softclassification
// tag::evaluate-data-frame-request
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1>
indexName, // <2>
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
new BinarySoftClassification( // <4>
"label", // <5>
"p", // <6>
// Evaluation metrics // <7>
PrecisionMetric.at(0.4, 0.5, 0.6), // <8>
RecallMetric.at(0.5, 0.7), // <9>
ConfusionMatrixMetric.at(0.5), // <10>
AucRocMetric.withCurve())); // <11>
EvaluateDataFrameRequest request =
new EvaluateDataFrameRequest( // <1>
indexName, // <2>
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
evaluation); // <4>
// end::evaluate-data-frame-request
// tag::evaluate-data-frame-execute
@ -3229,16 +3250,18 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// tag::evaluate-data-frame-response
List<EvaluationMetric.Result> metrics = response.getMetrics(); // <1>
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2>
double precision = precisionResult.getScoreByThreshold("0.4"); // <3>
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4>
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5>
// end::evaluate-data-frame-response
// tag::evaluate-data-frame-results-softclassification
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1>
double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <4>
// end::evaluate-data-frame-results-softclassification
assertThat(
metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()),
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
assertThat(precision, closeTo(0.6, 1e-9));
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
@ -3284,6 +3307,140 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
}
}
// Documentation IT for the classification flavor of the evaluate-data-frame API.
// Builds a small labeled index, runs a Classification evaluation with a multiclass
// confusion matrix metric, and pins the exact matrix contents. The tag::/end:: markers
// delimit snippets that are included verbatim in the asciidoc reference docs — keep
// them aligned with the docs when editing.
public void testEvaluateDataFrame_Classification() throws Exception {
String indexName = "evaluate-classification-test-index";
// Both fields are keyword-mapped: classification evaluation compares discrete class labels.
CreateIndexRequest createIndexRequest =
new CreateIndexRequest(indexName)
.mapping(XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject("actual_class")
.field("type", "keyword")
.endObject()
.startObject("predicted_class")
.field("type", "keyword")
.endObject()
.endObject()
.endObject());
// Fixture: 5 "cat" docs (3 correct, 1 as dog, 1 as fox), 4 "dog" docs (3 correct,
// 1 as cat), 1 "ant" doc (predicted cat). IMMEDIATE refresh so the evaluation
// request below sees all docs without an explicit refresh call.
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
{
// tag::evaluate-data-frame-evaluation-classification
Evaluation evaluation =
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
"actual_class", // <2>
"predicted_class", // <3>
// Evaluation metrics // <4>
new MulticlassConfusionMatrixMetric(3)); // <5>
// end::evaluate-data-frame-evaluation-classification
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
// tag::evaluate-data-frame-results-classification
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
// end::evaluate-data-frame-results-classification
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
// Expected matrix derived from the fixture above. The matrix size is 3, so only
// the top 3 actual classes (cat, dog, ant) get rows; "fox" only ever appears as a
// prediction and is folded into the "_other_" column of the cat row (doc #4).
assertThat(
confusionMatrix,
equalTo(
new HashMap<String, Map<String, Long>>() {{
put("cat", new HashMap<String, Long>() {{
put("cat", 3L);
put("dog", 1L);
put("ant", 0L);
put("_other_", 1L);
}});
put("dog", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 3L);
put("ant", 0L);
}});
put("ant", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 0L);
put("ant", 0L);
}});
}}));
// All 3 actual classes fit in the size-3 matrix, so nothing is left over.
assertThat(otherClassesCount, equalTo(0L));
}
}
// Documentation IT for the regression flavor of the evaluate-data-frame API.
// Builds a small index of actual/predicted doubles, runs a Regression evaluation
// with MSE and R-squared metrics, and pins the expected metric values. The
// tag::/end:: markers delimit snippets included verbatim in the reference docs.
public void testEvaluateDataFrame_Regression() throws Exception {
// FIX: was "evaluate-classification-test-index" — a copy-paste from the
// classification test above. Using a distinct name avoids any collision with
// that test's index (different mappings) and makes failures easier to attribute.
String indexName = "evaluate-regression-test-index";
CreateIndexRequest createIndexRequest =
new CreateIndexRequest(indexName)
.mapping(XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject("actual_value")
.field("type", "double")
.endObject()
.startObject("predicted_value")
.field("type", "double")
.endObject()
.endObject()
.endObject());
// Fixture: per-doc errors are 0, .1, 0, .1, .1, .3, 0, .2, .2, .1 so that
// MSE = sum(err^2)/10 = 0.21/10 = 0.021, matching the assertion below.
// IMMEDIATE refresh so the evaluation sees all docs without an explicit refresh.
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 1.0)) // #0
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 0.9)) // #1
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.0, "predicted_value", 2.0)) // #2
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.5, "predicted_value", 1.4)) // #3
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.2, "predicted_value", 1.3)) // #4
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.7, "predicted_value", 2.0)) // #5
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.1, "predicted_value", 2.1)) // #6
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.7)) // #7
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 0.8, "predicted_value", 1.0)) // #8
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.4)); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
{
// tag::evaluate-data-frame-evaluation-regression
Evaluation evaluation =
new org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression( // <1>
"actual_value", // <2>
"predicted_value", // <3>
// Evaluation metrics // <4>
new MeanSquaredErrorMetric(), // <5>
new RSquaredMetric()); // <6>
// end::evaluate-data-frame-evaluation-regression
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
// tag::evaluate-data-frame-results-regression
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
double meanSquaredError = meanSquaredErrorResult.getError(); // <2>
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3>
double rSquared = rSquaredResult.getValue(); // <4>
// end::evaluate-data-frame-results-regression
// 0.021 derived from the fixture (see comment above); loose 1e-3 tolerance
// since the server computes in floating point.
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
assertThat(rSquared, closeTo(0.941, 1e-3));
}
}
public void testEstimateMemoryUsage() throws Exception {
createIndex("estimate-test-source-index");
BulkRequest bulkRequest =

View File

@ -20,14 +20,52 @@ include-tagged::{doc-tests-file}[{api}-request]
<1> Constructing a new evaluation request
<2> Reference to an existing index
<3> The query with which to select data from indices
<4> Kind of evaluation to perform
<5> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
<6> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
<7> The remaining parameters are the metrics to be calculated based on the two fields described above.
<8> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
<9> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
<10> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
<11> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
<4> Evaluation to be performed
==== Evaluation
Evaluation to be performed.
Currently, supported evaluations include: +BinarySoftClassification+, +Classification+, +Regression+.
===== Binary soft classification
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-evaluation-softclassification]
--------------------------------------------------
<1> Constructing a new evaluation
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false.
<3> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive.
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
<5> https://en.wikipedia.org/wiki/Precision_and_recall#Precision[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
<6> https://en.wikipedia.org/wiki/Precision_and_recall#Recall[Recall] calculated at thresholds: 0.5 and 0.7
<7> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
<8> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
===== Classification
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-evaluation-classification]
--------------------------------------------------
<1> Constructing a new evaluation
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
<5> Multiclass confusion matrix of size 3
===== Regression
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-evaluation-regression]
--------------------------------------------------
<1> Constructing a new evaluation
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) value for an example.
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) value for the example.
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
<5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error]
<6> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
include::../execution.asciidoc[]
@ -41,7 +79,40 @@ The returned +{response}+ contains the requested evaluation metrics.
include-tagged::{doc-tests-file}[{api}-response]
--------------------------------------------------
<1> Fetching all the calculated metrics results
<2> Fetching precision metric by name
<3> Fetching precision at a given (0.4) threshold
<4> Fetching confusion matrix metric by name
<5> Fetching confusion matrix at a given (0.5) threshold
==== Results
===== Binary soft classification
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-results-softclassification]
--------------------------------------------------
<1> Fetching precision metric by name
<2> Fetching precision at a given (0.4) threshold
<3> Fetching confusion matrix metric by name
<4> Fetching confusion matrix at a given (0.5) threshold
===== Classification
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-results-classification]
--------------------------------------------------
<1> Fetching multiclass confusion matrix metric by name
<2> Fetching the contents of the confusion matrix
<3> Fetching the number of classes that were not included in the matrix
===== Regression
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-results-regression]
--------------------------------------------------
<1> Fetching mean squared error metric by name
<2> Fetching the actual mean squared error value
<3> Fetching R squared metric by name
<4> Fetching the actual R squared value

View File

@ -76,7 +76,7 @@ include-tagged::{doc-tests-file}[{api}-dest-config]
==== Analysis
The analysis to be performed.
Currently, the supported analyses include : +OutlierDetection+, +Regression+.
Currently, the supported analyses include: +OutlierDetection+, +Classification+, +Regression+.
===== Outlier detection
@ -101,6 +101,24 @@ include-tagged::{doc-tests-file}[{api}-outlier-detection-customized]
<6> The proportion of the data set that is assumed to be outlying prior to outlier detection
<7> Whether to apply standardization to feature values
===== Classification
+Classification+ analysis requires setting the +dependent_variable+ and
has a number of other optional parameters:
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-classification]
--------------------------------------------------
<1> Constructing a new Classification builder object with the required dependent variable
<2> The lambda regularization parameter. A non-negative double.
<3> The gamma regularization parameter. A non-negative double.
<4> The applied shrinkage. A double in [0.001, 1].
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
<7> The name of the prediction field in the results object.
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
===== Regression
+Regression+ analysis requires setting the +dependent_variable+ and