diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 6fc533c0f72..48c75e6e34e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -22,10 +22,12 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; +import java.util.Locale; import java.util.Objects; public class Classification implements DataFrameAnalysis { @@ -49,6 +51,7 @@ public class Classification implements DataFrameAnalysis { static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -67,7 +70,8 @@ public class Classification implements DataFrameAnalysis { (String) a[7], (Double) a[8], (Integer) a[9], - (Long) a[10])); + (Long) a[10], + (ClassAssignmentObjective) a[11])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -81,6 +85,12 @@ public class Classification implements DataFrameAnalysis { 
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); + PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ClassAssignmentObjective.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING); } private final String dependentVariable; @@ -92,13 +102,15 @@ public class Classification implements DataFrameAnalysis { private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; + private final ClassAssignmentObjective classAssignmentObjective; private final Integer numTopClasses; private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maxTrees, @Nullable Double featureBagFraction, @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, - @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { + @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed, + @Nullable ClassAssignmentObjective classAssignmentObjective) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -108,6 +120,7 @@ public class Classification implements DataFrameAnalysis { this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; + this.classAssignmentObjective = classAssignmentObjective; this.numTopClasses = numTopClasses; 
this.randomizeSeed = randomizeSeed; } @@ -157,6 +170,10 @@ public class Classification implements DataFrameAnalysis { return randomizeSeed; } + public ClassAssignmentObjective getClassAssignmentObjective() { + return classAssignmentObjective; + } + public Integer getNumTopClasses() { return numTopClasses; } @@ -192,6 +209,9 @@ public class Classification implements DataFrameAnalysis { if (randomizeSeed != null) { builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); } + if (classAssignmentObjective != null) { + builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); + } if (numTopClasses != null) { builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); } @@ -202,7 +222,7 @@ public class Classification implements DataFrameAnalysis { @Override public int hashCode() { return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues, - predictionFieldName, trainingPercent, randomizeSeed, numTopClasses); + predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective); } @Override @@ -220,7 +240,8 @@ public class Classification implements DataFrameAnalysis { && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed) - && Objects.equals(numTopClasses, that.numTopClasses); + && Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(classAssignmentObjective, that.classAssignmentObjective); } @Override @@ -228,6 +249,19 @@ public class Classification implements DataFrameAnalysis { return Strings.toString(this); } + public enum ClassAssignmentObjective { + MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL; + + public static ClassAssignmentObjective fromString(String value) { + return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return 
name().toLowerCase(Locale.ROOT); + } + } + public static class Builder { private String dependentVariable; private Double lambda; @@ -240,6 +274,7 @@ public class Classification implements DataFrameAnalysis { private Double trainingPercent; private Integer numTopClasses; private Long randomizeSeed; + private ClassAssignmentObjective classAssignmentObjective; private Builder(String dependentVariable) { this.dependentVariable = Objects.requireNonNull(dependentVariable); @@ -295,9 +330,15 @@ public class Classification implements DataFrameAnalysis { return this; } + public Builder setClassAssignmentObjective(ClassAssignmentObjective classAssignmentObjective) { + this.classAssignmentObjective = classAssignmentObjective; + return this; + } + public Classification build() { return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, - numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed); + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed, + classAssignmentObjective); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index c91f0f77728..57f6a82448c 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1366,6 +1366,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setRandomizeSeed(42L) + .setClassAssignmentObjective( + org.elasticsearch.client.ml.dataframe.Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) .setNumTopClasses(1) .setLambda(1.0) .setGamma(1.0) diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 92d97f5f5a1..8393774493e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfig; import org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.client.ml.dataframe.Classification; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; @@ -2969,7 +2970,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { // end::put-data-frame-analytics-outlier-detection-customized // tag::put-data-frame-analytics-classification - DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1> + DataFrameAnalysis classification = Classification.builder("my_dependent_variable") // <1> .setLambda(1.0) // <2> .setGamma(5.5) // <3> .setEta(5.5) // <4> @@ -2979,7 +2980,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setPredictionFieldName("my_prediction_field_name") // <8> .setTrainingPercent(50.0) // <9> .setRandomizeSeed(1234L) // <10> - .setNumTopClasses(1) // <11> + .setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11> + .setNumTopClasses(1) // <12> .build(); // end::put-data-frame-analytics-classification diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 906003e2009..0970222c513 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -36,6 +36,7 @@ public class ClassificationTests extends AbstractXContentTestCase The name of the prediction field in the results object. <9> The percentage of training-eligible rows to be used in training. Defaults to 100%. <10> The seed to be used by the random generator that picks which rows are used in training. -<11> The number of top classes to be reported in the results. Defaults to 2. +<11> The optimization objective to target when assigning class labels. Defaults to `maximize_minimum_recall`. +<12> The number of top classes to be reported in the results. Defaults to 2. 
===== Regression diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index e03b22cf309..e8c82ba114f 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -138,6 +138,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma] (Optional, double) include::{docdir}/ml/ml-shared.asciidoc[tag=lambda] +`analysis`.`classification`.`class_assignment_objective`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=class-assignment-objective] + `analysis`.`classification`.`num_top_classes`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 0c012abfe53..a1a11437e4d 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -334,6 +334,14 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=mode] include::{docdir}/ml/ml-shared.asciidoc[tag=time-span] end::chunking-config[] +tag::class-assignment-objective[] +Defines the objective to optimize when assigning class labels. Available +objectives are `maximize_accuracy` and `maximize_minimum_recall`. When maximizing +accuracy, class labels are chosen to maximize the number of correct predictions. +When maximizing minimum recall, labels are chosen to maximize the minimum recall +for any class. Defaults to `maximize_minimum_recall`. +end::class-assignment-objective[] + tag::compute-feature-influence[] If `true`, the feature influence calculation is enabled. Defaults to `true`. 
end::compute-feature-influence[] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index ece6f6a278b..ddb6a921200 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.Randomness; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.FieldAliasMapper; @@ -21,6 +22,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -37,6 +39,7 @@ public class Classification implements DataFrameAnalysis { public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); + public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -54,12 +57,19 @@ public class Classification implements DataFrameAnalysis { (String) a[0], new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) 
a[4], (Double) a[5], (Integer) a[6]), (String) a[7], - (Integer) a[8], - (Double) a[9], - (Long) a[10])); + (ClassAssignmentObjective) a[8], + (Integer) a[9], + (Double) a[10], + (Long) a[11])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); + parser.declareField(optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ClassAssignmentObjective.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING); parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); @@ -91,6 +101,7 @@ public class Classification implements DataFrameAnalysis { private final String dependentVariable; private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; + private final ClassAssignmentObjective classAssignmentObjective; private final int numTopClasses; private final double trainingPercent; private final long randomizeSeed; @@ -98,6 +109,7 @@ public class Classification implements DataFrameAnalysis { public Classification(String dependentVariable, BoostedTreeParams boostedTreeParams, @Nullable String predictionFieldName, + @Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable Integer numTopClasses, @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { @@ -110,19 +122,26 @@ public class Classification implements DataFrameAnalysis { this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); this.predictionFieldName = predictionFieldName == null ? 
dependentVariable + "_prediction" : predictionFieldName; + this.classAssignmentObjective = classAssignmentObjective == null ? + ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL : classAssignmentObjective; this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; } public Classification(String dependentVariable) { - this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null); } public Classification(StreamInput in) throws IOException { dependentVariable = in.readString(); boostedTreeParams = new BoostedTreeParams(in); predictionFieldName = in.readOptionalString(); + if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class); + } else { + classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL; + } numTopClasses = in.readOptionalVInt(); trainingPercent = in.readDouble(); if (in.getVersion().onOrAfter(Version.V_7_6_0)) { @@ -144,6 +163,10 @@ public class Classification implements DataFrameAnalysis { return predictionFieldName; } + public ClassAssignmentObjective getClassAssignmentObjective() { + return classAssignmentObjective; + } + public int getNumTopClasses() { return numTopClasses; } @@ -166,6 +189,9 @@ public class Classification implements DataFrameAnalysis { out.writeString(dependentVariable); boostedTreeParams.writeTo(out); out.writeOptionalString(predictionFieldName); + if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeEnum(classAssignmentObjective); + } out.writeOptionalVInt(numTopClasses); out.writeDouble(trainingPercent); if (out.getVersion().onOrAfter(Version.V_7_6_0)) { @@ -180,6 +206,7 @@ public class Classification implements DataFrameAnalysis { 
builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); + builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); @@ -197,6 +224,7 @@ public class Classification implements DataFrameAnalysis { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); + params.put(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); @@ -305,6 +333,7 @@ public class Classification implements DataFrameAnalysis { return Objects.equals(dependentVariable, that.dependentVariable) && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) + && Objects.equals(classAssignmentObjective, that.classAssignmentObjective) && Objects.equals(numTopClasses, that.numTopClasses) && trainingPercent == that.trainingPercent && randomizeSeed == that.randomizeSeed; @@ -312,6 +341,20 @@ public class Classification implements DataFrameAnalysis { @Override public int hashCode() { - return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed); + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective, + numTopClasses, trainingPercent, randomizeSeed); + } + + public enum ClassAssignmentObjective { + MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL; + + public static ClassAssignmentObjective fromString(String value) { + return 
ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 9dbd9a3986d..2414664c297 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -318,6 +318,7 @@ public final class ReservedFieldNames { Classification.NAME.getPreferredName(), Classification.DEPENDENT_VARIABLE.getPreferredName(), Classification.PREDICTION_FIELD_NAME.getPreferredName(), + Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), Classification.NUM_TOP_CLASSES.getPreferredName(), Classification.TRAINING_PERCENT.getPreferredName(), BoostedTreeParams.LAMBDA.getPreferredName(), diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json index 961ad77f65e..0d4444209fe 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json @@ -43,6 +43,9 @@ "max_trees" : { "type" : "integer" }, + "class_assignment_objective" : { + "type" : "keyword" + }, "num_top_classes" : { "type" : "integer" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 8992530c1db..a04bbb70290 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -150,12 +150,14 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC bwcAnalysis = new Classification(bwcClassification.getDependentVariable(), bwcClassification.getBoostedTreeParams(), bwcClassification.getPredictionFieldName(), + bwcClassification.getClassAssignmentObjective(), bwcClassification.getNumTopClasses(), bwcClassification.getTrainingPercent(), 42L); testAnalysis = new Classification(testClassification.getDependentVariable(), testClassification.getBoostedTreeParams(), testClassification.getPredictionFieldName(), + testClassification.getClassAssignmentObjective(), testClassification.getNumTopClasses(), testClassification.getTrainingPercent(), 42L); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index a475602128e..ef3f4f0082c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -55,17 +55,20 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", 
BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenNumTopClassesIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testGetPredictionFieldName() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("result")); - classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction")); } + public void testClassAssignmentObjective() { + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", + Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong()); + 
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)); + + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", + Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong()); + assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); + + // class_assignment_objective == null, default applied + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); + assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); + } + public void testGetNumTopClasses() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(7)); // Boundary condition: num_top_classes == 0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(0)); // Boundary condition: num_top_classes == 1000 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(1000)); // num_top_classes == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), 
equalTo(2)); } public void testGetTrainingPercent() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); } @@ -177,6 +196,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase>allOf( hasEntry("dependent_variable", "foo"), + hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("num_top_classes", 2), hasEntry("prediction_field_name", "foo_prediction"), hasEntry("prediction_field_type", "bool"))); @@ -184,6 +204,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase>allOf( hasEntry("dependent_variable", "bar"), + hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("num_top_classes", 2), hasEntry("prediction_field_name", 
"bar_prediction"), hasEntry("prediction_field_type", "int"))); @@ -191,6 +212,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase>allOf( hasEntry("dependent_variable", "baz"), + hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), hasEntry("num_top_classes", 2), hasEntry("prediction_field_name", "baz_prediction"), hasEntry("prediction_field_type", "string"))); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 4269d8570c6..e403f4cf8b4 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -89,6 +89,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { null, null, null, + null, null)); registerAnalytics(config); putAnalytics(config); @@ -190,7 +191,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -438,7 +439,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null)); registerAnalytics(firstJob); 
putAnalytics(firstJob); @@ -447,7 +448,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed(); DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed)); registerAnalytics(secondJob); putAnalytics(secondJob); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 8884941ddfa..a5a99b30391 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1834,6 +1834,7 @@ setup: "eta": 0.5, "max_trees": 400, "feature_bag_fraction": 0.3, + "class_assignment_objective": "maximize_accuracy", "training_percent": 60.3, "randomize_seed": 24 } @@ -1853,6 +1854,7 @@ setup: "prediction_field_name": "foo_prediction", "training_percent": 60.3, "randomize_seed": 24, + "class_assignment_objective": "maximize_accuracy", "num_top_classes": 2 } }} @@ -1896,6 +1898,7 @@ setup: "prediction_field_name": "foo_prediction", "training_percent": 100.0, "randomize_seed": 24, + "class_assignment_objective": "maximize_minimum_recall", "num_top_classes": 2 } }}