[7.x][ML] Adds the class_assignment_objective parameter to classification (#53552)
Adds a new parameter for classification that enables choosing whether to assign labels to maximise accuracy or to maximise the minimum class recall. Fixes #52427.
This commit is contained in:
parent
b1d589f276
commit
690099553c
|
@ -22,10 +22,12 @@ import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Locale;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public class Classification implements DataFrameAnalysis {
|
public class Classification implements DataFrameAnalysis {
|
||||||
|
@ -49,6 +51,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
||||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||||
|
static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
|
||||||
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||||
|
|
||||||
|
@ -67,7 +70,8 @@ public class Classification implements DataFrameAnalysis {
|
||||||
(String) a[7],
|
(String) a[7],
|
||||||
(Double) a[8],
|
(Double) a[8],
|
||||||
(Integer) a[9],
|
(Integer) a[9],
|
||||||
(Long) a[10]));
|
(Long) a[10],
|
||||||
|
(ClassAssignmentObjective) a[11]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||||
|
@ -81,6 +85,12 @@ public class Classification implements DataFrameAnalysis {
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||||
|
PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
|
||||||
|
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||||
|
return ClassAssignmentObjective.fromString(p.text());
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
|
||||||
|
}, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final String dependentVariable;
|
private final String dependentVariable;
|
||||||
|
@ -92,13 +102,15 @@ public class Classification implements DataFrameAnalysis {
|
||||||
private final Integer numTopFeatureImportanceValues;
|
private final Integer numTopFeatureImportanceValues;
|
||||||
private final String predictionFieldName;
|
private final String predictionFieldName;
|
||||||
private final Double trainingPercent;
|
private final Double trainingPercent;
|
||||||
|
private final ClassAssignmentObjective classAssignmentObjective;
|
||||||
private final Integer numTopClasses;
|
private final Integer numTopClasses;
|
||||||
private final Long randomizeSeed;
|
private final Long randomizeSeed;
|
||||||
|
|
||||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||||
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
||||||
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||||
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
|
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
|
||||||
|
@Nullable ClassAssignmentObjective classAssignmentObjective) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
this.gamma = gamma;
|
this.gamma = gamma;
|
||||||
|
@ -108,6 +120,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||||
this.predictionFieldName = predictionFieldName;
|
this.predictionFieldName = predictionFieldName;
|
||||||
this.trainingPercent = trainingPercent;
|
this.trainingPercent = trainingPercent;
|
||||||
|
this.classAssignmentObjective = classAssignmentObjective;
|
||||||
this.numTopClasses = numTopClasses;
|
this.numTopClasses = numTopClasses;
|
||||||
this.randomizeSeed = randomizeSeed;
|
this.randomizeSeed = randomizeSeed;
|
||||||
}
|
}
|
||||||
|
@ -157,6 +170,10 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return randomizeSeed;
|
return randomizeSeed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ClassAssignmentObjective getClassAssignmentObjective() {
|
||||||
|
return classAssignmentObjective;
|
||||||
|
}
|
||||||
|
|
||||||
public Integer getNumTopClasses() {
|
public Integer getNumTopClasses() {
|
||||||
return numTopClasses;
|
return numTopClasses;
|
||||||
}
|
}
|
||||||
|
@ -192,6 +209,9 @@ public class Classification implements DataFrameAnalysis {
|
||||||
if (randomizeSeed != null) {
|
if (randomizeSeed != null) {
|
||||||
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
|
||||||
}
|
}
|
||||||
|
if (classAssignmentObjective != null) {
|
||||||
|
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
|
||||||
|
}
|
||||||
if (numTopClasses != null) {
|
if (numTopClasses != null) {
|
||||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||||
}
|
}
|
||||||
|
@ -202,7 +222,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||||
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
|
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -220,7 +240,8 @@ public class Classification implements DataFrameAnalysis {
|
||||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||||
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
||||||
&& Objects.equals(numTopClasses, that.numTopClasses);
|
&& Objects.equals(numTopClasses, that.numTopClasses)
|
||||||
|
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -228,6 +249,19 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return Strings.toString(this);
|
return Strings.toString(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public enum ClassAssignmentObjective {
|
||||||
|
MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL;
|
||||||
|
|
||||||
|
public static ClassAssignmentObjective fromString(String value) {
|
||||||
|
return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return name().toLowerCase(Locale.ROOT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
private String dependentVariable;
|
private String dependentVariable;
|
||||||
private Double lambda;
|
private Double lambda;
|
||||||
|
@ -240,6 +274,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
private Double trainingPercent;
|
private Double trainingPercent;
|
||||||
private Integer numTopClasses;
|
private Integer numTopClasses;
|
||||||
private Long randomizeSeed;
|
private Long randomizeSeed;
|
||||||
|
private ClassAssignmentObjective classAssignmentObjective;
|
||||||
|
|
||||||
private Builder(String dependentVariable) {
|
private Builder(String dependentVariable) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
|
@ -295,9 +330,15 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setClassAssignmentObjective(ClassAssignmentObjective classAssignmentObjective) {
|
||||||
|
this.classAssignmentObjective = classAssignmentObjective;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Classification build() {
|
public Classification build() {
|
||||||
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
||||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
|
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
|
||||||
|
classAssignmentObjective);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1366,6 +1366,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
.setPredictionFieldName("my_dependent_variable_prediction")
|
.setPredictionFieldName("my_dependent_variable_prediction")
|
||||||
.setTrainingPercent(80.0)
|
.setTrainingPercent(80.0)
|
||||||
.setRandomizeSeed(42L)
|
.setRandomizeSeed(42L)
|
||||||
|
.setClassAssignmentObjective(
|
||||||
|
org.elasticsearch.client.ml.dataframe.Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)
|
||||||
.setNumTopClasses(1)
|
.setNumTopClasses(1)
|
||||||
.setLambda(1.0)
|
.setLambda(1.0)
|
||||||
.setGamma(1.0)
|
.setGamma(1.0)
|
||||||
|
|
|
@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfig;
|
||||||
import org.elasticsearch.client.ml.datafeed.DatafeedStats;
|
import org.elasticsearch.client.ml.datafeed.DatafeedStats;
|
||||||
import org.elasticsearch.client.ml.datafeed.DatafeedUpdate;
|
import org.elasticsearch.client.ml.datafeed.DatafeedUpdate;
|
||||||
import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig;
|
import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.Classification;
|
||||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
||||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest;
|
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest;
|
||||||
|
@ -2969,7 +2970,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
// end::put-data-frame-analytics-outlier-detection-customized
|
// end::put-data-frame-analytics-outlier-detection-customized
|
||||||
|
|
||||||
// tag::put-data-frame-analytics-classification
|
// tag::put-data-frame-analytics-classification
|
||||||
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
|
DataFrameAnalysis classification = Classification.builder("my_dependent_variable") // <1>
|
||||||
.setLambda(1.0) // <2>
|
.setLambda(1.0) // <2>
|
||||||
.setGamma(5.5) // <3>
|
.setGamma(5.5) // <3>
|
||||||
.setEta(5.5) // <4>
|
.setEta(5.5) // <4>
|
||||||
|
@ -2979,7 +2980,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
.setPredictionFieldName("my_prediction_field_name") // <8>
|
.setPredictionFieldName("my_prediction_field_name") // <8>
|
||||||
.setTrainingPercent(50.0) // <9>
|
.setTrainingPercent(50.0) // <9>
|
||||||
.setRandomizeSeed(1234L) // <10>
|
.setRandomizeSeed(1234L) // <10>
|
||||||
.setNumTopClasses(1) // <11>
|
.setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
|
||||||
|
.setNumTopClasses(1) // <12>
|
||||||
.build();
|
.build();
|
||||||
// end::put-data-frame-analytics-classification
|
// end::put-data-frame-analytics-classification
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
||||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||||
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
||||||
|
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
|
||||||
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,7 +121,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
||||||
<8> The name of the prediction field in the results object.
|
<8> The name of the prediction field in the results object.
|
||||||
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||||
<11> The number of top classes to be reported in the results. Defaults to 2.
|
<11> The optimization objective to target when assigning class labels. Defaults to `maximize_minimum_recall`.
|
||||||
|
<12> The number of top classes to be reported in the results. Defaults to 2.
|
||||||
|
|
||||||
===== Regression
|
===== Regression
|
||||||
|
|
||||||
|
|
|
@ -138,6 +138,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma]
|
||||||
(Optional, double)
|
(Optional, double)
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
|
||||||
|
|
||||||
|
`analysis`.`classification`.`class_assignment_objective`::::
|
||||||
|
(Optional, string)
|
||||||
|
include::{docdir}/ml/ml-shared.asciidoc[tag=class-assignment-objective]
|
||||||
|
|
||||||
`analysis`.`classification`.`num_top_classes`::::
|
`analysis`.`classification`.`num_top_classes`::::
|
||||||
(Optional, integer)
|
(Optional, integer)
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes]
|
||||||
|
|
|
@ -334,6 +334,14 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=mode]
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=time-span]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=time-span]
|
||||||
end::chunking-config[]
|
end::chunking-config[]
|
||||||
|
|
||||||
|
tag::class-assignment-objective[]
|
||||||
|
Defines the objective to optimize when assigning class labels. Available
|
||||||
|
objectives are `maximize_accuracy` and `maximize_minimum_recall`. When maximizing
|
||||||
|
accuracy, class labels are chosen to maximize the number of correct predictions.
|
||||||
|
When maximizing minimum recall, labels are chosen to maximize the minimum recall
|
||||||
|
for any class. Defaults to `maximize_minimum_recall`.
|
||||||
|
end::class-assignment-objective[]
|
||||||
|
|
||||||
tag::compute-feature-influence[]
|
tag::compute-feature-influence[]
|
||||||
If `true`, the feature influence calculation is enabled. Defaults to `true`.
|
If `true`, the feature influence calculation is enabled. Defaults to `true`.
|
||||||
end::compute-feature-influence[]
|
end::compute-feature-influence[]
|
||||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.common.Randomness;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.index.mapper.FieldAliasMapper;
|
import org.elasticsearch.index.mapper.FieldAliasMapper;
|
||||||
|
@ -21,6 +22,7 @@ import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
@ -37,6 +39,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
|
|
||||||
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
|
public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
|
||||||
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||||
|
public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
|
||||||
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||||
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||||
|
@ -54,12 +57,19 @@ public class Classification implements DataFrameAnalysis {
|
||||||
(String) a[0],
|
(String) a[0],
|
||||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
|
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
|
||||||
(String) a[7],
|
(String) a[7],
|
||||||
(Integer) a[8],
|
(ClassAssignmentObjective) a[8],
|
||||||
(Double) a[9],
|
(Integer) a[9],
|
||||||
(Long) a[10]));
|
(Double) a[10],
|
||||||
|
(Long) a[11]));
|
||||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||||
BoostedTreeParams.declareFields(parser);
|
BoostedTreeParams.declareFields(parser);
|
||||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||||
|
parser.declareField(optionalConstructorArg(), p -> {
|
||||||
|
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||||
|
return ClassAssignmentObjective.fromString(p.text());
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
|
||||||
|
}, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
|
||||||
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
|
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||||
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
|
||||||
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
|
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
|
||||||
|
@ -91,6 +101,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
private final String dependentVariable;
|
private final String dependentVariable;
|
||||||
private final BoostedTreeParams boostedTreeParams;
|
private final BoostedTreeParams boostedTreeParams;
|
||||||
private final String predictionFieldName;
|
private final String predictionFieldName;
|
||||||
|
private final ClassAssignmentObjective classAssignmentObjective;
|
||||||
private final int numTopClasses;
|
private final int numTopClasses;
|
||||||
private final double trainingPercent;
|
private final double trainingPercent;
|
||||||
private final long randomizeSeed;
|
private final long randomizeSeed;
|
||||||
|
@ -98,6 +109,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
public Classification(String dependentVariable,
|
public Classification(String dependentVariable,
|
||||||
BoostedTreeParams boostedTreeParams,
|
BoostedTreeParams boostedTreeParams,
|
||||||
@Nullable String predictionFieldName,
|
@Nullable String predictionFieldName,
|
||||||
|
@Nullable ClassAssignmentObjective classAssignmentObjective,
|
||||||
@Nullable Integer numTopClasses,
|
@Nullable Integer numTopClasses,
|
||||||
@Nullable Double trainingPercent,
|
@Nullable Double trainingPercent,
|
||||||
@Nullable Long randomizeSeed) {
|
@Nullable Long randomizeSeed) {
|
||||||
|
@ -110,19 +122,26 @@ public class Classification implements DataFrameAnalysis {
|
||||||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||||
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
|
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
|
||||||
|
this.classAssignmentObjective = classAssignmentObjective == null ?
|
||||||
|
ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL : classAssignmentObjective;
|
||||||
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
|
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
|
||||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||||
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
|
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Classification(String dependentVariable) {
|
public Classification(String dependentVariable) {
|
||||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null);
|
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Classification(StreamInput in) throws IOException {
|
public Classification(StreamInput in) throws IOException {
|
||||||
dependentVariable = in.readString();
|
dependentVariable = in.readString();
|
||||||
boostedTreeParams = new BoostedTreeParams(in);
|
boostedTreeParams = new BoostedTreeParams(in);
|
||||||
predictionFieldName = in.readOptionalString();
|
predictionFieldName = in.readOptionalString();
|
||||||
|
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||||
|
classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class);
|
||||||
|
} else {
|
||||||
|
classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL;
|
||||||
|
}
|
||||||
numTopClasses = in.readOptionalVInt();
|
numTopClasses = in.readOptionalVInt();
|
||||||
trainingPercent = in.readDouble();
|
trainingPercent = in.readDouble();
|
||||||
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||||
|
@ -144,6 +163,10 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return predictionFieldName;
|
return predictionFieldName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ClassAssignmentObjective getClassAssignmentObjective() {
|
||||||
|
return classAssignmentObjective;
|
||||||
|
}
|
||||||
|
|
||||||
public int getNumTopClasses() {
|
public int getNumTopClasses() {
|
||||||
return numTopClasses;
|
return numTopClasses;
|
||||||
}
|
}
|
||||||
|
@ -166,6 +189,9 @@ public class Classification implements DataFrameAnalysis {
|
||||||
out.writeString(dependentVariable);
|
out.writeString(dependentVariable);
|
||||||
boostedTreeParams.writeTo(out);
|
boostedTreeParams.writeTo(out);
|
||||||
out.writeOptionalString(predictionFieldName);
|
out.writeOptionalString(predictionFieldName);
|
||||||
|
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||||
|
out.writeEnum(classAssignmentObjective);
|
||||||
|
}
|
||||||
out.writeOptionalVInt(numTopClasses);
|
out.writeOptionalVInt(numTopClasses);
|
||||||
out.writeDouble(trainingPercent);
|
out.writeDouble(trainingPercent);
|
||||||
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||||
|
@ -180,6 +206,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||||
boostedTreeParams.toXContent(builder, params);
|
boostedTreeParams.toXContent(builder, params);
|
||||||
|
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
|
||||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||||
if (predictionFieldName != null) {
|
if (predictionFieldName != null) {
|
||||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||||
|
@ -197,6 +224,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
Map<String, Object> params = new HashMap<>();
|
Map<String, Object> params = new HashMap<>();
|
||||||
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
|
||||||
params.putAll(boostedTreeParams.getParams());
|
params.putAll(boostedTreeParams.getParams());
|
||||||
|
params.put(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
|
||||||
params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||||
if (predictionFieldName != null) {
|
if (predictionFieldName != null) {
|
||||||
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||||
|
@ -305,6 +333,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return Objects.equals(dependentVariable, that.dependentVariable)
|
return Objects.equals(dependentVariable, that.dependentVariable)
|
||||||
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
&& Objects.equals(boostedTreeParams, that.boostedTreeParams)
|
||||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||||
|
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
|
||||||
&& Objects.equals(numTopClasses, that.numTopClasses)
|
&& Objects.equals(numTopClasses, that.numTopClasses)
|
||||||
&& trainingPercent == that.trainingPercent
|
&& trainingPercent == that.trainingPercent
|
||||||
&& randomizeSeed == that.randomizeSeed;
|
&& randomizeSeed == that.randomizeSeed;
|
||||||
|
@ -312,6 +341,20 @@ public class Classification implements DataFrameAnalysis {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed);
|
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||||
|
numTopClasses, trainingPercent, randomizeSeed);
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum ClassAssignmentObjective {
|
||||||
|
MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL;
|
||||||
|
|
||||||
|
public static ClassAssignmentObjective fromString(String value) {
|
||||||
|
return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return name().toLowerCase(Locale.ROOT);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -318,6 +318,7 @@ public final class ReservedFieldNames {
|
||||||
Classification.NAME.getPreferredName(),
|
Classification.NAME.getPreferredName(),
|
||||||
Classification.DEPENDENT_VARIABLE.getPreferredName(),
|
Classification.DEPENDENT_VARIABLE.getPreferredName(),
|
||||||
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
|
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||||
|
Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(),
|
||||||
Classification.NUM_TOP_CLASSES.getPreferredName(),
|
Classification.NUM_TOP_CLASSES.getPreferredName(),
|
||||||
Classification.TRAINING_PERCENT.getPreferredName(),
|
Classification.TRAINING_PERCENT.getPreferredName(),
|
||||||
BoostedTreeParams.LAMBDA.getPreferredName(),
|
BoostedTreeParams.LAMBDA.getPreferredName(),
|
||||||
|
|
|
@ -43,6 +43,9 @@
|
||||||
"max_trees" : {
|
"max_trees" : {
|
||||||
"type" : "integer"
|
"type" : "integer"
|
||||||
},
|
},
|
||||||
|
"class_assignment_objective" : {
|
||||||
|
"type" : "keyword"
|
||||||
|
},
|
||||||
"num_top_classes" : {
|
"num_top_classes" : {
|
||||||
"type" : "integer"
|
"type" : "integer"
|
||||||
},
|
},
|
||||||
|
|
|
@ -150,12 +150,14 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
||||||
bwcAnalysis = new Classification(bwcClassification.getDependentVariable(),
|
bwcAnalysis = new Classification(bwcClassification.getDependentVariable(),
|
||||||
bwcClassification.getBoostedTreeParams(),
|
bwcClassification.getBoostedTreeParams(),
|
||||||
bwcClassification.getPredictionFieldName(),
|
bwcClassification.getPredictionFieldName(),
|
||||||
|
bwcClassification.getClassAssignmentObjective(),
|
||||||
bwcClassification.getNumTopClasses(),
|
bwcClassification.getNumTopClasses(),
|
||||||
bwcClassification.getTrainingPercent(),
|
bwcClassification.getTrainingPercent(),
|
||||||
42L);
|
42L);
|
||||||
testAnalysis = new Classification(testClassification.getDependentVariable(),
|
testAnalysis = new Classification(testClassification.getDependentVariable(),
|
||||||
testClassification.getBoostedTreeParams(),
|
testClassification.getBoostedTreeParams(),
|
||||||
testClassification.getPredictionFieldName(),
|
testClassification.getPredictionFieldName(),
|
||||||
|
testClassification.getClassAssignmentObjective(),
|
||||||
testClassification.getNumTopClasses(),
|
testClassification.getNumTopClasses(),
|
||||||
testClassification.getTrainingPercent(),
|
testClassification.getTrainingPercent(),
|
||||||
42L);
|
42L);
|
||||||
|
|
|
@ -55,17 +55,20 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
String dependentVariableName = randomAlphaOfLength(10);
|
String dependentVariableName = randomAlphaOfLength(10);
|
||||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||||
|
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
|
||||||
|
null : randomFrom(Classification.ClassAssignmentObjective.values());
|
||||||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
||||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent,
|
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||||
randomizeSeed);
|
numTopClasses, trainingPercent, randomizeSeed);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Classification mutateForVersion(Classification instance, Version version) {
|
public static Classification mutateForVersion(Classification instance, Version version) {
|
||||||
return new Classification(instance.getDependentVariable(),
|
return new Classification(instance.getDependentVariable(),
|
||||||
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
|
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
|
||||||
instance.getPredictionFieldName(),
|
instance.getPredictionFieldName(),
|
||||||
|
version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null,
|
||||||
instance.getNumTopClasses(),
|
instance.getNumTopClasses(),
|
||||||
instance.getTrainingPercent(),
|
instance.getTrainingPercent(),
|
||||||
instance.getRandomizeSeed());
|
instance.getRandomizeSeed());
|
||||||
|
@ -81,12 +84,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(),
|
Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(),
|
||||||
bwcSerializedObject.getBoostedTreeParams(),
|
bwcSerializedObject.getBoostedTreeParams(),
|
||||||
bwcSerializedObject.getPredictionFieldName(),
|
bwcSerializedObject.getPredictionFieldName(),
|
||||||
|
bwcSerializedObject.getClassAssignmentObjective(),
|
||||||
bwcSerializedObject.getNumTopClasses(),
|
bwcSerializedObject.getNumTopClasses(),
|
||||||
bwcSerializedObject.getTrainingPercent(),
|
bwcSerializedObject.getTrainingPercent(),
|
||||||
42L);
|
42L);
|
||||||
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
||||||
testInstance.getBoostedTreeParams(),
|
testInstance.getBoostedTreeParams(),
|
||||||
testInstance.getPredictionFieldName(),
|
testInstance.getPredictionFieldName(),
|
||||||
|
testInstance.getClassAssignmentObjective(),
|
||||||
testInstance.getNumTopClasses(),
|
testInstance.getNumTopClasses(),
|
||||||
testInstance.getTrainingPercent(),
|
testInstance.getTrainingPercent(),
|
||||||
42L);
|
42L);
|
||||||
|
@ -100,71 +105,85 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
|
|
||||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong()));
|
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong()));
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong()));
|
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong()));
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong()));
|
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong()));
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong()));
|
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong()));
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetPredictionFieldName() {
|
public void testGetPredictionFieldName() {
|
||||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
||||||
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
||||||
|
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong());
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong());
|
||||||
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testClassAssignmentObjective() {
|
||||||
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
||||||
|
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong());
|
||||||
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
|
||||||
|
|
||||||
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
||||||
|
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong());
|
||||||
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
||||||
|
|
||||||
|
// class_assignment_objective == null, default applied
|
||||||
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
||||||
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
||||||
|
}
|
||||||
|
|
||||||
public void testGetNumTopClasses() {
|
public void testGetNumTopClasses() {
|
||||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong());
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
||||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||||
|
|
||||||
// Boundary condition: num_top_classes == 0
|
// Boundary condition: num_top_classes == 0
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong());
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong());
|
||||||
assertThat(classification.getNumTopClasses(), equalTo(0));
|
assertThat(classification.getNumTopClasses(), equalTo(0));
|
||||||
|
|
||||||
// Boundary condition: num_top_classes == 1000
|
// Boundary condition: num_top_classes == 1000
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong());
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong());
|
||||||
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
||||||
|
|
||||||
// num_top_classes == null, default applied
|
// num_top_classes == null, default applied
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong());
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong());
|
||||||
assertThat(classification.getNumTopClasses(), equalTo(2));
|
assertThat(classification.getNumTopClasses(), equalTo(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetTrainingPercent() {
|
public void testGetTrainingPercent() {
|
||||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
||||||
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
||||||
|
|
||||||
// Boundary condition: training_percent == 1.0
|
// Boundary condition: training_percent == 1.0
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong());
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong());
|
||||||
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
||||||
|
|
||||||
// Boundary condition: training_percent == 100.0
|
// Boundary condition: training_percent == 100.0
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong());
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong());
|
||||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||||
|
|
||||||
// training_percent == null, default applied
|
// training_percent == null, default applied
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong());
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong());
|
||||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,6 +196,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
new Classification("foo").getParams(extractedFields),
|
new Classification("foo").getParams(extractedFields),
|
||||||
Matchers.<Map<String, Object>>allOf(
|
Matchers.<Map<String, Object>>allOf(
|
||||||
hasEntry("dependent_variable", "foo"),
|
hasEntry("dependent_variable", "foo"),
|
||||||
|
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
|
||||||
hasEntry("num_top_classes", 2),
|
hasEntry("num_top_classes", 2),
|
||||||
hasEntry("prediction_field_name", "foo_prediction"),
|
hasEntry("prediction_field_name", "foo_prediction"),
|
||||||
hasEntry("prediction_field_type", "bool")));
|
hasEntry("prediction_field_type", "bool")));
|
||||||
|
@ -184,6 +204,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
new Classification("bar").getParams(extractedFields),
|
new Classification("bar").getParams(extractedFields),
|
||||||
Matchers.<Map<String, Object>>allOf(
|
Matchers.<Map<String, Object>>allOf(
|
||||||
hasEntry("dependent_variable", "bar"),
|
hasEntry("dependent_variable", "bar"),
|
||||||
|
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
|
||||||
hasEntry("num_top_classes", 2),
|
hasEntry("num_top_classes", 2),
|
||||||
hasEntry("prediction_field_name", "bar_prediction"),
|
hasEntry("prediction_field_name", "bar_prediction"),
|
||||||
hasEntry("prediction_field_type", "int")));
|
hasEntry("prediction_field_type", "int")));
|
||||||
|
@ -191,6 +212,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
new Classification("baz").getParams(extractedFields),
|
new Classification("baz").getParams(extractedFields),
|
||||||
Matchers.<Map<String, Object>>allOf(
|
Matchers.<Map<String, Object>>allOf(
|
||||||
hasEntry("dependent_variable", "baz"),
|
hasEntry("dependent_variable", "baz"),
|
||||||
|
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
|
||||||
hasEntry("num_top_classes", 2),
|
hasEntry("num_top_classes", 2),
|
||||||
hasEntry("prediction_field_name", "baz_prediction"),
|
hasEntry("prediction_field_name", "baz_prediction"),
|
||||||
hasEntry("prediction_field_type", "string")));
|
hasEntry("prediction_field_type", "string")));
|
||||||
|
|
|
@ -89,6 +89,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
null));
|
null));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
putAnalytics(config);
|
putAnalytics(config);
|
||||||
|
@ -190,7 +191,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
sourceIndex,
|
sourceIndex,
|
||||||
destIndex,
|
destIndex,
|
||||||
null,
|
null,
|
||||||
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null));
|
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
putAnalytics(config);
|
putAnalytics(config);
|
||||||
|
|
||||||
|
@ -438,7 +439,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||||
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
|
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null));
|
||||||
registerAnalytics(firstJob);
|
registerAnalytics(firstJob);
|
||||||
putAnalytics(firstJob);
|
putAnalytics(firstJob);
|
||||||
|
|
||||||
|
@ -447,7 +448,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
|
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
|
||||||
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
|
||||||
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed));
|
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed));
|
||||||
|
|
||||||
registerAnalytics(secondJob);
|
registerAnalytics(secondJob);
|
||||||
putAnalytics(secondJob);
|
putAnalytics(secondJob);
|
||||||
|
|
|
@ -1834,6 +1834,7 @@ setup:
|
||||||
"eta": 0.5,
|
"eta": 0.5,
|
||||||
"max_trees": 400,
|
"max_trees": 400,
|
||||||
"feature_bag_fraction": 0.3,
|
"feature_bag_fraction": 0.3,
|
||||||
|
"class_assignment_objective": "maximize_accuracy",
|
||||||
"training_percent": 60.3,
|
"training_percent": 60.3,
|
||||||
"randomize_seed": 24
|
"randomize_seed": 24
|
||||||
}
|
}
|
||||||
|
@ -1853,6 +1854,7 @@ setup:
|
||||||
"prediction_field_name": "foo_prediction",
|
"prediction_field_name": "foo_prediction",
|
||||||
"training_percent": 60.3,
|
"training_percent": 60.3,
|
||||||
"randomize_seed": 24,
|
"randomize_seed": 24,
|
||||||
|
"class_assignment_objective": "maximize_accuracy",
|
||||||
"num_top_classes": 2
|
"num_top_classes": 2
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
|
@ -1896,6 +1898,7 @@ setup:
|
||||||
"prediction_field_name": "foo_prediction",
|
"prediction_field_name": "foo_prediction",
|
||||||
"training_percent": 100.0,
|
"training_percent": 100.0,
|
||||||
"randomize_seed": 24,
|
"randomize_seed": 24,
|
||||||
|
"class_assignment_objective": "maximize_minimum_recall",
|
||||||
"num_top_classes": 2
|
"num_top_classes": 2
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
|
|
Loading…
Reference in New Issue