[7.x][ML] Adds the class_assignment_objective parameter to classification (#53552)

Adds a new parameter for classification that enables choosing whether to assign labels to
maximise accuracy or to maximise the minimum class recall.

Fixes #52427.
This commit is contained in:
Tom Veasey 2020-03-13 17:35:51 +00:00 committed by GitHub
parent b1d589f276
commit 690099553c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 166 additions and 32 deletions

View File

@ -22,10 +22,12 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
public class Classification implements DataFrameAnalysis {
@ -49,6 +51,7 @@ public class Classification implements DataFrameAnalysis {
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
@ -67,7 +70,8 @@ public class Classification implements DataFrameAnalysis {
(String) a[7],
(Double) a[8],
(Integer) a[9],
(Long) a[10]));
(Long) a[10],
(ClassAssignmentObjective) a[11]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@ -81,6 +85,12 @@ public class Classification implements DataFrameAnalysis {
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return ClassAssignmentObjective.fromString(p.text());
}
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
}
private final String dependentVariable;
@ -92,13 +102,15 @@ public class Classification implements DataFrameAnalysis {
private final Integer numTopFeatureImportanceValues;
private final String predictionFieldName;
private final Double trainingPercent;
private final ClassAssignmentObjective classAssignmentObjective;
private final Integer numTopClasses;
private final Long randomizeSeed;
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
@Nullable ClassAssignmentObjective classAssignmentObjective) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -108,6 +120,7 @@ public class Classification implements DataFrameAnalysis {
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.classAssignmentObjective = classAssignmentObjective;
this.numTopClasses = numTopClasses;
this.randomizeSeed = randomizeSeed;
}
@ -157,6 +170,10 @@ public class Classification implements DataFrameAnalysis {
return randomizeSeed;
}
public ClassAssignmentObjective getClassAssignmentObjective() {
return classAssignmentObjective;
}
public Integer getNumTopClasses() {
return numTopClasses;
}
@ -192,6 +209,9 @@ public class Classification implements DataFrameAnalysis {
if (randomizeSeed != null) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
if (classAssignmentObjective != null) {
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
}
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
@ -202,7 +222,7 @@ public class Classification implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective);
}
@Override
@ -220,7 +240,8 @@ public class Classification implements DataFrameAnalysis {
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed)
&& Objects.equals(numTopClasses, that.numTopClasses);
&& Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective);
}
@Override
@ -228,6 +249,19 @@ public class Classification implements DataFrameAnalysis {
return Strings.toString(this);
}
public enum ClassAssignmentObjective {
MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL;
public static ClassAssignmentObjective fromString(String value) {
return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
public static class Builder {
private String dependentVariable;
private Double lambda;
@ -240,6 +274,7 @@ public class Classification implements DataFrameAnalysis {
private Double trainingPercent;
private Integer numTopClasses;
private Long randomizeSeed;
private ClassAssignmentObjective classAssignmentObjective;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -295,9 +330,15 @@ public class Classification implements DataFrameAnalysis {
return this;
}
public Builder setClassAssignmentObjective(ClassAssignmentObjective classAssignmentObjective) {
this.classAssignmentObjective = classAssignmentObjective;
return this;
}
public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
classAssignmentObjective);
}
}
}

View File

@ -1366,6 +1366,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setRandomizeSeed(42L)
.setClassAssignmentObjective(
org.elasticsearch.client.ml.dataframe.Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)
.setNumTopClasses(1)
.setLambda(1.0)
.setGamma(1.0)

View File

@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfig;
import org.elasticsearch.client.ml.datafeed.DatafeedStats;
import org.elasticsearch.client.ml.datafeed.DatafeedUpdate;
import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig;
import org.elasticsearch.client.ml.dataframe.Classification;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest;
@ -2969,7 +2970,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// end::put-data-frame-analytics-outlier-detection-customized
// tag::put-data-frame-analytics-classification
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
DataFrameAnalysis classification = Classification.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
@ -2979,7 +2980,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setPredictionFieldName("my_prediction_field_name") // <8>
.setTrainingPercent(50.0) // <9>
.setRandomizeSeed(1234L) // <10>
.setNumTopClasses(1) // <11>
.setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
.setNumTopClasses(1) // <12>
.build();
// end::put-data-frame-analytics-classification

View File

@ -36,6 +36,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setRandomizeSeed(randomBoolean() ? null : randomLong())
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
.build();
}

View File

@ -121,7 +121,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
<8> The name of the prediction field in the results object.
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<10> The seed to be used by the random generator that picks which rows are used in training.
<11> The number of top classes to be reported in the results. Defaults to 2.
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
<12> The number of top classes to be reported in the results. Defaults to 2.
===== Regression

View File

@ -138,6 +138,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma]
(Optional, double)
include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
`analysis`.`classification`.`class_assignment_objective`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=class-assignment-objective]
`analysis`.`classification`.`num_top_classes`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes]

View File

@ -334,6 +334,14 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=mode]
include::{docdir}/ml/ml-shared.asciidoc[tag=time-span]
end::chunking-config[]
tag::class-assignment-objective[]
Defines the objective to optimize when assigning class labels. Available
objectives are `maximize_accuracy` and `maximize_minimum_recall`. When maximizing
accuracy class labels are chosen to maximize the number of correct predictions.
When maximizing minimum recall labels are chosen to maximize the minimum recall
for any class. Defaults to maximize_minimum_recall.
end::class-assignment-objective[]
tag::compute-feature-influence[]
If `true`, the feature influence calculation is enabled. Defaults to `true`.
end::compute-feature-influence[]

View File

@ -12,6 +12,7 @@ import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.FieldAliasMapper;
@ -21,6 +22,7 @@ import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
@ -37,6 +39,7 @@ public class Classification implements DataFrameAnalysis {
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
@ -54,12 +57,19 @@ public class Classification implements DataFrameAnalysis {
(String) a[0],
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
(String) a[7],
(Integer) a[8],
(Double) a[9],
(Long) a[10]));
(ClassAssignmentObjective) a[8],
(Integer) a[9],
(Double) a[10],
(Long) a[11]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareField(optionalConstructorArg(), p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return ClassAssignmentObjective.fromString(p.text());
}
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
@ -91,6 +101,7 @@ public class Classification implements DataFrameAnalysis {
private final String dependentVariable;
private final BoostedTreeParams boostedTreeParams;
private final String predictionFieldName;
private final ClassAssignmentObjective classAssignmentObjective;
private final int numTopClasses;
private final double trainingPercent;
private final long randomizeSeed;
@ -98,6 +109,7 @@ public class Classification implements DataFrameAnalysis {
public Classification(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable ClassAssignmentObjective classAssignmentObjective,
@Nullable Integer numTopClasses,
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed) {
@ -110,19 +122,26 @@ public class Classification implements DataFrameAnalysis {
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.classAssignmentObjective = classAssignmentObjective == null ?
ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL : classAssignmentObjective;
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
}
public Classification(String dependentVariable) {
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null);
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
}
public Classification(StreamInput in) throws IOException {
dependentVariable = in.readString();
boostedTreeParams = new BoostedTreeParams(in);
predictionFieldName = in.readOptionalString();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class);
} else {
classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL;
}
numTopClasses = in.readOptionalVInt();
trainingPercent = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
@ -144,6 +163,10 @@ public class Classification implements DataFrameAnalysis {
return predictionFieldName;
}
public ClassAssignmentObjective getClassAssignmentObjective() {
return classAssignmentObjective;
}
public int getNumTopClasses() {
return numTopClasses;
}
@ -166,6 +189,9 @@ public class Classification implements DataFrameAnalysis {
out.writeString(dependentVariable);
boostedTreeParams.writeTo(out);
out.writeOptionalString(predictionFieldName);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeEnum(classAssignmentObjective);
}
out.writeOptionalVInt(numTopClasses);
out.writeDouble(trainingPercent);
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
@ -180,6 +206,7 @@ public class Classification implements DataFrameAnalysis {
builder.startObject();
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
boostedTreeParams.toXContent(builder, params);
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
if (predictionFieldName != null) {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
@ -197,6 +224,7 @@ public class Classification implements DataFrameAnalysis {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());
params.put(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
@ -305,6 +333,7 @@ public class Classification implements DataFrameAnalysis {
return Objects.equals(dependentVariable, that.dependentVariable)
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
&& Objects.equals(numTopClasses, that.numTopClasses)
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed;
@ -312,6 +341,20 @@ public class Classification implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed);
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed);
}
public enum ClassAssignmentObjective {
MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL;
public static ClassAssignmentObjective fromString(String value) {
return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
}
}

View File

@ -318,6 +318,7 @@ public final class ReservedFieldNames {
Classification.NAME.getPreferredName(),
Classification.DEPENDENT_VARIABLE.getPreferredName(),
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(),
Classification.NUM_TOP_CLASSES.getPreferredName(),
Classification.TRAINING_PERCENT.getPreferredName(),
BoostedTreeParams.LAMBDA.getPreferredName(),

View File

@ -43,6 +43,9 @@
"max_trees" : {
"type" : "integer"
},
"class_assignment_objective" : {
"type" : "keyword"
},
"num_top_classes" : {
"type" : "integer"
},

View File

@ -150,12 +150,14 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
bwcAnalysis = new Classification(bwcClassification.getDependentVariable(),
bwcClassification.getBoostedTreeParams(),
bwcClassification.getPredictionFieldName(),
bwcClassification.getClassAssignmentObjective(),
bwcClassification.getNumTopClasses(),
bwcClassification.getTrainingPercent(),
42L);
testAnalysis = new Classification(testClassification.getDependentVariable(),
testClassification.getBoostedTreeParams(),
testClassification.getPredictionFieldName(),
testClassification.getClassAssignmentObjective(),
testClassification.getNumTopClasses(),
testClassification.getTrainingPercent(),
42L);

View File

@ -55,17 +55,20 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
null : randomFrom(Classification.ClassAssignmentObjective.values());
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent,
randomizeSeed);
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed);
}
public static Classification mutateForVersion(Classification instance, Version version) {
return new Classification(instance.getDependentVariable(),
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
instance.getPredictionFieldName(),
version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null,
instance.getNumTopClasses(),
instance.getTrainingPercent(),
instance.getRandomizeSeed());
@ -81,12 +84,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(),
bwcSerializedObject.getBoostedTreeParams(),
bwcSerializedObject.getPredictionFieldName(),
bwcSerializedObject.getClassAssignmentObjective(),
bwcSerializedObject.getNumTopClasses(),
bwcSerializedObject.getTrainingPercent(),
42L);
Classification newInstance = new Classification(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
testInstance.getClassAssignmentObjective(),
testInstance.getNumTopClasses(),
testInstance.getTrainingPercent(),
42L);
@ -100,71 +105,85 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong()));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong()));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong()));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong()));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong()));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong()));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong()));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong()));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}
public void testGetPredictionFieldName() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
assertThat(classification.getPredictionFieldName(), equalTo("result"));
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong());
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong());
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testClassAssignmentObjective() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong());
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong());
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
// class_assignment_objective == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
}
public void testGetNumTopClasses() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong());
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(7));
// Boundary condition: num_top_classes == 0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong());
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(0));
// Boundary condition: num_top_classes == 1000
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong());
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(1000));
// num_top_classes == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong());
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong());
assertThat(classification.getNumTopClasses(), equalTo(2));
}
public void testGetTrainingPercent() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong());
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong());
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong());
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong());
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}
@ -177,6 +196,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
new Classification("foo").getParams(extractedFields),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "foo"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "foo_prediction"),
hasEntry("prediction_field_type", "bool")));
@ -184,6 +204,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
new Classification("bar").getParams(extractedFields),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "bar"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "bar_prediction"),
hasEntry("prediction_field_type", "int")));
@ -191,6 +212,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
new Classification("baz").getParams(extractedFields),
Matchers.<Map<String, Object>>allOf(
hasEntry("dependent_variable", "baz"),
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
hasEntry("num_top_classes", 2),
hasEntry("prediction_field_name", "baz_prediction"),
hasEntry("prediction_field_type", "string")));

View File

@ -89,6 +89,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null));
registerAnalytics(config);
putAnalytics(config);
@ -190,7 +191,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null));
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null));
registerAnalytics(config);
putAnalytics(config);
@ -438,7 +439,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null));
registerAnalytics(firstJob);
putAnalytics(firstJob);
@ -447,7 +448,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed));
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed));
registerAnalytics(secondJob);
putAnalytics(secondJob);

View File

@ -1834,6 +1834,7 @@ setup:
"eta": 0.5,
"max_trees": 400,
"feature_bag_fraction": 0.3,
"class_assignment_objective": "maximize_accuracy",
"training_percent": 60.3,
"randomize_seed": 24
}
@ -1853,6 +1854,7 @@ setup:
"prediction_field_name": "foo_prediction",
"training_percent": 60.3,
"randomize_seed": 24,
"class_assignment_objective": "maximize_accuracy",
"num_top_classes": 2
}
}}
@ -1896,6 +1898,7 @@ setup:
"prediction_field_name": "foo_prediction",
"training_percent": 100.0,
"randomize_seed": 24,
"class_assignment_objective": "maximize_minimum_recall",
"num_top_classes": 2
}
}}