This commit is contained in:
parent
e0489fc328
commit
28f68fa221
|
@ -48,6 +48,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
|
||||
private static final ConstructingObjectParser<Classification, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
|
@ -61,7 +62,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
(Integer) a[4],
|
||||
(Double) a[5],
|
||||
(String) a[6],
|
||||
(Double) a[7]));
|
||||
(Double) a[7],
|
||||
(Integer) a[8]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
|
@ -72,6 +74,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
|
@ -82,10 +85,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
private final Double featureBagFraction;
|
||||
private final String predictionFieldName;
|
||||
private final Double trainingPercent;
|
||||
private final Integer numTopClasses;
|
||||
|
||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent) {
|
||||
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
|
@ -94,6 +98,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.featureBagFraction = featureBagFraction;
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.numTopClasses = numTopClasses;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -133,6 +138,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
return trainingPercent;
|
||||
}
|
||||
|
||||
public Integer getNumTopClasses() {
|
||||
return numTopClasses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
@ -158,6 +167,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (trainingPercent != null) {
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
}
|
||||
if (numTopClasses != null) {
|
||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -165,7 +177,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
trainingPercent, numTopClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -180,7 +192,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent);
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||
&& Objects.equals(numTopClasses, that.numTopClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -197,6 +210,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
private Double featureBagFraction;
|
||||
private String predictionFieldName;
|
||||
private Double trainingPercent;
|
||||
private Integer numTopClasses;
|
||||
|
||||
private Builder(String dependentVariable) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
|
@ -237,9 +251,14 @@ public class Classification implements DataFrameAnalysis {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumTopClasses(Integer numTopClasses) {
|
||||
this.numTopClasses = numTopClasses;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Classification build() {
|
||||
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
trainingPercent, numTopClasses);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1296,8 +1296,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setDest(DataFrameAnalyticsDest.builder()
|
||||
.setIndex("put-test-dest-index")
|
||||
.build())
|
||||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression
|
||||
.builder("my_dependent_variable")
|
||||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
|
||||
.setTrainingPercent(80.0)
|
||||
.build())
|
||||
.setDescription("this is a regression")
|
||||
|
@ -1331,9 +1330,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setDest(DataFrameAnalyticsDest.builder()
|
||||
.setIndex("put-test-dest-index")
|
||||
.build())
|
||||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification
|
||||
.builder("my_dependent_variable")
|
||||
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
|
||||
.setTrainingPercent(80.0)
|
||||
.setNumTopClasses(1)
|
||||
.build())
|
||||
.setDescription("this is a classification")
|
||||
.build();
|
||||
|
|
|
@ -2951,6 +2951,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setFeatureBagFraction(0.4) // <6>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
||||
.setTrainingPercent(50.0) // <8>
|
||||
.setNumTopClasses(1) // <9>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-classification
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
||||
.build();
|
||||
}
|
||||
|
||||
|
|
|
@ -118,6 +118,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
|||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||
<7> The name of the prediction field in the results object.
|
||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<9> The number of top classes to be reported in the results. Defaults to 2.
|
||||
|
||||
===== Regression
|
||||
|
||||
|
|
|
@ -67,6 +67,12 @@ public class Classification implements DataFrameAnalysis {
|
|||
.flatMap(Set::stream)
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
/**
|
||||
* As long as we only support binary classification it makes sense to always report both classes with their probabilities.
|
||||
* This way the user can see if the prediction was made with confidence they need.
|
||||
*/
|
||||
private static final int DEFAULT_NUM_TOP_CLASSES = 2;
|
||||
|
||||
private final String dependentVariable;
|
||||
private final BoostedTreeParams boostedTreeParams;
|
||||
private final String predictionFieldName;
|
||||
|
@ -87,7 +93,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
|
||||
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
}
|
||||
|
||||
|
@ -107,6 +113,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
return dependentVariable;
|
||||
}
|
||||
|
||||
public int getNumTopClasses() {
|
||||
return numTopClasses;
|
||||
}
|
||||
|
||||
public double getTrainingPercent() {
|
||||
return trainingPercent;
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ import static org.hamcrest.Matchers.nullValue;
|
|||
|
||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
|
||||
|
||||
@Override
|
||||
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
||||
return Classification.fromXContent(parser, false);
|
||||
|
@ -43,32 +45,68 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
return Classification::new;
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsNull() {
|
||||
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsBoundary() {
|
||||
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
||||
classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
}
|
||||
|
||||
public void testGetNumTopClasses() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||
|
||||
// Boundary condition: num_top_classes == 0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(0));
|
||||
|
||||
// Boundary condition: num_top_classes == 1000
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
||||
|
||||
// num_top_classes == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(2));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testFieldCardinalityLimitsIsNonNull() {
|
||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
|
||||
}
|
||||
|
|
|
@ -83,7 +83,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
@ -120,7 +120,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(true));
|
||||
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
|
|
@ -1810,7 +1810,7 @@ setup:
|
|||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3,
|
||||
"num_top_classes": 0
|
||||
"num_top_classes": 2
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
|
|
Loading…
Reference in New Issue