mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-24 17:09:48 +00:00
This commit is contained in:
parent
e60221d2bd
commit
d210bfa888
@ -1776,7 +1776,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
new EvaluateDataFrameRequest(
|
||||
regressionIndex,
|
||||
null,
|
||||
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
||||
new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
@ -1933,7 +1933,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
}
|
||||
|
||||
private static final String actualRegression = "regression_actual";
|
||||
private static final String probabilityRegression = "regression_prob";
|
||||
private static final String predictedRegression = "regression_predicted";
|
||||
|
||||
private static XContentBuilder mappingForRegression() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
@ -1941,17 +1941,17 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
.startObject(actualRegression)
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.startObject(probabilityRegression)
|
||||
.startObject(predictedRegression)
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForRegression(String indexName, double act, double p) {
|
||||
private static IndexRequest docForRegression(String indexName, double actualValue, double predictedValue) {
|
||||
return new IndexRequest()
|
||||
.index(indexName)
|
||||
.source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
|
||||
.source(XContentType.JSON, actualRegression, actualValue, predictedRegression, predictedValue);
|
||||
}
|
||||
|
||||
private void createIndex(String indexName, XContentBuilder mapping) throws IOException {
|
||||
|
@ -139,8 +139,11 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
||||
@ -2821,7 +2824,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
List<DataFrameAnalyticsConfig> configs = response.getAnalytics();
|
||||
// end::get-data-frame-analytics-response
|
||||
|
||||
assertThat(configs.size(), equalTo(1));
|
||||
assertThat(configs, hasSize(1));
|
||||
}
|
||||
{
|
||||
GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config");
|
||||
@ -2871,7 +2874,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
List<DataFrameAnalyticsStats> stats = response.getAnalyticsStats();
|
||||
// end::get-data-frame-analytics-stats-response
|
||||
|
||||
assertThat(stats.size(), equalTo(1));
|
||||
assertThat(stats, hasSize(1));
|
||||
}
|
||||
{
|
||||
GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config");
|
||||
@ -2939,8 +2942,20 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
.build();
|
||||
// end::put-data-frame-analytics-outlier-detection-customized
|
||||
|
||||
// tag::put-data-frame-analytics-classification
|
||||
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
|
||||
.setLambda(1.0) // <2>
|
||||
.setGamma(5.5) // <3>
|
||||
.setEta(5.5) // <4>
|
||||
.setMaximumNumberTrees(50) // <5>
|
||||
.setFeatureBagFraction(0.4) // <6>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
||||
.setTrainingPercent(50.0) // <8>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-classification
|
||||
|
||||
// tag::put-data-frame-analytics-regression
|
||||
DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
|
||||
DataFrameAnalysis regression = org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") // <1>
|
||||
.setLambda(1.0) // <2>
|
||||
.setGamma(5.5) // <3>
|
||||
.setEta(5.5) // <4>
|
||||
@ -3209,18 +3224,24 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
||||
{
|
||||
// tag::evaluate-data-frame-evaluation-softclassification
|
||||
Evaluation evaluation =
|
||||
new BinarySoftClassification( // <1>
|
||||
"label", // <2>
|
||||
"p", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
|
||||
RecallMetric.at(0.5, 0.7), // <6>
|
||||
ConfusionMatrixMetric.at(0.5), // <7>
|
||||
AucRocMetric.withCurve()); // <8>
|
||||
// end::evaluate-data-frame-evaluation-softclassification
|
||||
|
||||
// tag::evaluate-data-frame-request
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1>
|
||||
indexName, // <2>
|
||||
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
|
||||
new BinarySoftClassification( // <4>
|
||||
"label", // <5>
|
||||
"p", // <6>
|
||||
// Evaluation metrics // <7>
|
||||
PrecisionMetric.at(0.4, 0.5, 0.6), // <8>
|
||||
RecallMetric.at(0.5, 0.7), // <9>
|
||||
ConfusionMatrixMetric.at(0.5), // <10>
|
||||
AucRocMetric.withCurve())); // <11>
|
||||
EvaluateDataFrameRequest request =
|
||||
new EvaluateDataFrameRequest( // <1>
|
||||
indexName, // <2>
|
||||
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
|
||||
evaluation); // <4>
|
||||
// end::evaluate-data-frame-request
|
||||
|
||||
// tag::evaluate-data-frame-execute
|
||||
@ -3229,16 +3250,18 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
|
||||
// tag::evaluate-data-frame-response
|
||||
List<EvaluationMetric.Result> metrics = response.getMetrics(); // <1>
|
||||
|
||||
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2>
|
||||
double precision = precisionResult.getScoreByThreshold("0.4"); // <3>
|
||||
|
||||
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4>
|
||||
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5>
|
||||
// end::evaluate-data-frame-response
|
||||
|
||||
// tag::evaluate-data-frame-results-softclassification
|
||||
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1>
|
||||
double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
|
||||
|
||||
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
|
||||
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <4>
|
||||
// end::evaluate-data-frame-results-softclassification
|
||||
|
||||
assertThat(
|
||||
metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()),
|
||||
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
|
||||
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
|
||||
assertThat(precision, closeTo(0.6, 1e-9));
|
||||
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
|
||||
@ -3284,6 +3307,140 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
public void testEvaluateDataFrame_Classification() throws Exception {
|
||||
String indexName = "evaluate-classification-test-index";
|
||||
CreateIndexRequest createIndexRequest =
|
||||
new CreateIndexRequest(indexName)
|
||||
.mapping(XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
.startObject("actual_class")
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.startObject("predicted_class")
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject());
|
||||
BulkRequest bulkRequest =
|
||||
new BulkRequest(indexName)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9
|
||||
RestHighLevelClient client = highLevelClient();
|
||||
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
||||
{
|
||||
// tag::evaluate-data-frame-evaluation-classification
|
||||
Evaluation evaluation =
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
|
||||
"actual_class", // <2>
|
||||
"predicted_class", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
new MulticlassConfusionMatrixMetric(3)); // <5>
|
||||
// end::evaluate-data-frame-evaluation-classification
|
||||
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
||||
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
|
||||
|
||||
// tag::evaluate-data-frame-results-classification
|
||||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
|
||||
|
||||
Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
|
||||
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
|
||||
// end::evaluate-data-frame-results-classification
|
||||
|
||||
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
assertThat(
|
||||
confusionMatrix,
|
||||
equalTo(
|
||||
new HashMap<String, Map<String, Long>>() {{
|
||||
put("cat", new HashMap<String, Long>() {{
|
||||
put("cat", 3L);
|
||||
put("dog", 1L);
|
||||
put("ant", 0L);
|
||||
put("_other_", 1L);
|
||||
}});
|
||||
put("dog", new HashMap<String, Long>() {{
|
||||
put("cat", 1L);
|
||||
put("dog", 3L);
|
||||
put("ant", 0L);
|
||||
}});
|
||||
put("ant", new HashMap<String, Long>() {{
|
||||
put("cat", 1L);
|
||||
put("dog", 0L);
|
||||
put("ant", 0L);
|
||||
}});
|
||||
}}));
|
||||
assertThat(otherClassesCount, equalTo(0L));
|
||||
}
|
||||
}
|
||||
|
||||
public void testEvaluateDataFrame_Regression() throws Exception {
|
||||
String indexName = "evaluate-classification-test-index";
|
||||
CreateIndexRequest createIndexRequest =
|
||||
new CreateIndexRequest(indexName)
|
||||
.mapping(XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
.startObject("actual_value")
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.startObject("predicted_value")
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject());
|
||||
BulkRequest bulkRequest =
|
||||
new BulkRequest(indexName)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 1.0)) // #0
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 0.9)) // #1
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.0, "predicted_value", 2.0)) // #2
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.5, "predicted_value", 1.4)) // #3
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.2, "predicted_value", 1.3)) // #4
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.7, "predicted_value", 2.0)) // #5
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.1, "predicted_value", 2.1)) // #6
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.7)) // #7
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 0.8, "predicted_value", 1.0)) // #8
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.4)); // #9
|
||||
RestHighLevelClient client = highLevelClient();
|
||||
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
||||
{
|
||||
// tag::evaluate-data-frame-evaluation-regression
|
||||
Evaluation evaluation =
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression( // <1>
|
||||
"actual_value", // <2>
|
||||
"predicted_value", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
new MeanSquaredErrorMetric(), // <5>
|
||||
new RSquaredMetric()); // <6>
|
||||
// end::evaluate-data-frame-evaluation-regression
|
||||
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
||||
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
|
||||
|
||||
// tag::evaluate-data-frame-results-regression
|
||||
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
|
||||
double meanSquaredError = meanSquaredErrorResult.getError(); // <2>
|
||||
|
||||
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3>
|
||||
double rSquared = rSquaredResult.getValue(); // <4>
|
||||
// end::evaluate-data-frame-results-regression
|
||||
|
||||
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
|
||||
assertThat(rSquared, closeTo(0.941, 1e-3));
|
||||
}
|
||||
}
|
||||
|
||||
public void testEstimateMemoryUsage() throws Exception {
|
||||
createIndex("estimate-test-source-index");
|
||||
BulkRequest bulkRequest =
|
||||
|
@ -20,14 +20,52 @@ include-tagged::{doc-tests-file}[{api}-request]
|
||||
<1> Constructing a new evaluation request
|
||||
<2> Reference to an existing index
|
||||
<3> The query with which to select data from indices
|
||||
<4> Kind of evaluation to perform
|
||||
<5> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
|
||||
<6> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
|
||||
<7> The remaining parameters are the metrics to be calculated based on the two fields described above.
|
||||
<8> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
|
||||
<9> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
|
||||
<10> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
|
||||
<11> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
|
||||
<4> Evaluation to be performed
|
||||
|
||||
==== Evaluation
|
||||
|
||||
Evaluation to be performed.
|
||||
Currently, supported evaluations include: +BinarySoftClassification+, +Classification+, +Regression+.
|
||||
|
||||
===== Binary soft classification
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-evaluation-softclassification]
|
||||
--------------------------------------------------
|
||||
<1> Constructing a new evaluation
|
||||
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false.
|
||||
<3> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive.
|
||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<5> https://en.wikipedia.org/wiki/Precision_and_recall#Precision[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
|
||||
<6> https://en.wikipedia.org/wiki/Precision_and_recall#Recall[Recall] calculated at thresholds: 0.5 and 0.7
|
||||
<7> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
|
||||
<8> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
|
||||
|
||||
===== Classification
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-evaluation-classification]
|
||||
--------------------------------------------------
|
||||
<1> Constructing a new evaluation
|
||||
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
|
||||
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
|
||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<5> Multiclass confusion matrix of size 3
|
||||
|
||||
===== Regression
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-evaluation-regression]
|
||||
--------------------------------------------------
|
||||
<1> Constructing a new evaluation
|
||||
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) value for an example.
|
||||
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) value for the example.
|
||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error]
|
||||
<6> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
|
||||
|
||||
include::../execution.asciidoc[]
|
||||
|
||||
@ -41,7 +79,40 @@ The returned +{response}+ contains the requested evaluation metrics.
|
||||
include-tagged::{doc-tests-file}[{api}-response]
|
||||
--------------------------------------------------
|
||||
<1> Fetching all the calculated metrics results
|
||||
<2> Fetching precision metric by name
|
||||
<3> Fetching precision at a given (0.4) threshold
|
||||
<4> Fetching confusion matrix metric by name
|
||||
<5> Fetching confusion matrix at a given (0.5) threshold
|
||||
|
||||
==== Results
|
||||
|
||||
===== Binary soft classification
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-results-softclassification]
|
||||
--------------------------------------------------
|
||||
|
||||
<1> Fetching precision metric by name
|
||||
<2> Fetching precision at a given (0.4) threshold
|
||||
<3> Fetching confusion matrix metric by name
|
||||
<4> Fetching confusion matrix at a given (0.5) threshold
|
||||
|
||||
===== Classification
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-results-classification]
|
||||
--------------------------------------------------
|
||||
|
||||
<1> Fetching multiclass confusion matrix metric by name
|
||||
<2> Fetching the contents of the confusion matrix
|
||||
<3> Fetching the number of classes that were not included in the matrix
|
||||
|
||||
===== Regression
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-results-regression]
|
||||
--------------------------------------------------
|
||||
|
||||
<1> Fetching mean squared error metric by name
|
||||
<2> Fetching the actual mean squared error value
|
||||
<3> Fetching R squared metric by name
|
||||
<4> Fetching the actual R squared value
|
@ -76,7 +76,7 @@ include-tagged::{doc-tests-file}[{api}-dest-config]
|
||||
==== Analysis
|
||||
|
||||
The analysis to be performed.
|
||||
Currently, the supported analyses include : +OutlierDetection+, +Regression+.
|
||||
Currently, the supported analyses include: +OutlierDetection+, +Classification+, +Regression+.
|
||||
|
||||
===== Outlier detection
|
||||
|
||||
@ -101,6 +101,24 @@ include-tagged::{doc-tests-file}[{api}-outlier-detection-customized]
|
||||
<6> The proportion of the data set that is assumed to be outlying prior to outlier detection
|
||||
<7> Whether to apply standardization to feature values
|
||||
|
||||
===== Classification
|
||||
|
||||
+Classification+ analysis requires to set which is the +dependent_variable+ and
|
||||
has a number of other optional parameters:
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-classification]
|
||||
--------------------------------------------------
|
||||
<1> Constructing a new Classification builder object with the required dependent variable
|
||||
<2> The lambda regularization parameter. A non-negative double.
|
||||
<3> The gamma regularization parameter. A non-negative double.
|
||||
<4> The applied shrinkage. A double in [0.001, 1].
|
||||
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
||||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||
<7> The name of the prediction field in the results object.
|
||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
|
||||
===== Regression
|
||||
|
||||
+Regression+ analysis requires to set which is the +dependent_variable+ and
|
||||
|
Loading…
x
Reference in New Issue
Block a user