This commit is contained in:
parent
69fc715bc3
commit
1a42e37070
|
@ -126,8 +126,8 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
|||
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
|
@ -1297,6 +1297,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setIndex("put-test-dest-index")
|
||||
.build())
|
||||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
|
||||
.setPredictionFieldName("my_dependent_variable_prediction")
|
||||
.setTrainingPercent(80.0)
|
||||
.build())
|
||||
.setDescription("this is a regression")
|
||||
|
@ -1331,6 +1332,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setIndex("put-test-dest-index")
|
||||
.build())
|
||||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
|
||||
.setPredictionFieldName("my_dependent_variable_prediction")
|
||||
.setTrainingPercent(80.0)
|
||||
.setNumTopClasses(1)
|
||||
.build())
|
||||
|
|
|
@ -92,7 +92,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
}
|
||||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
|
||||
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
}
|
||||
|
@ -113,6 +113,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
return dependentVariable;
|
||||
}
|
||||
|
||||
public String getPredictionFieldName() {
|
||||
return predictionFieldName;
|
||||
}
|
||||
|
||||
public int getNumTopClasses() {
|
||||
return numTopClasses;
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
}
|
||||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
}
|
||||
|
||||
|
@ -89,6 +89,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
return dependentVariable;
|
||||
}
|
||||
|
||||
public String getPredictionFieldName() {
|
||||
return predictionFieldName;
|
||||
}
|
||||
|
||||
public double getTrainingPercent() {
|
||||
return trainingPercent;
|
||||
}
|
||||
|
|
|
@ -73,6 +73,14 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0);
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testGetNumTopClasses() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||
|
|
|
@ -19,6 +19,8 @@ import static org.hamcrest.Matchers.nullValue;
|
|||
|
||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
|
||||
|
||||
@Override
|
||||
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
||||
return Regression.fromXContent(parser, false);
|
||||
|
@ -42,32 +44,45 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
return Regression::new;
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsNull() {
|
||||
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsBoundary() {
|
||||
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 1.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testFieldCardinalityLimitsIsNonNull() {
|
||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
|
||||
}
|
||||
|
|
|
@ -1470,6 +1470,7 @@ setup:
|
|||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 60.3
|
||||
}
|
||||
}}
|
||||
|
@ -1809,6 +1810,7 @@ setup:
|
|||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 60.3,
|
||||
"num_top_classes": 2
|
||||
}
|
||||
|
@ -1844,6 +1846,7 @@ setup:
|
|||
- match: { analysis: {
|
||||
"regression":{
|
||||
"dependent_variable": "foo",
|
||||
"prediction_field_name": "foo_prediction",
|
||||
"training_percent": 100.0
|
||||
}
|
||||
}}
|
||||
|
|
Loading…
Reference in New Issue