[ML] add training_percent to analytics process params (#54605) (#54678)

This adds training_percent parameter to the analytics process for Classification and Regression. This parameter is then used to give more accurate memory estimations.

See native side pr: elastic/ml-cpp#1111
This commit is contained in:
Benjamin Trent 2020-04-02 17:08:06 -04:00 committed by GitHub
parent 54ea4f4f50
commit 7fe38935f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 96 additions and 31 deletions

View File

@ -241,6 +241,7 @@ public class Classification implements DataFrameAnalysis {
params.put(PREDICTION_FIELD_TYPE, predictionFieldType); params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
} }
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable)); params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
return params; return params;
} }

View File

@ -163,6 +163,7 @@ public class Regression implements DataFrameAnalysis {
if (predictionFieldName != null) { if (predictionFieldName != null) {
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
} }
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
return params; return params;
} }

View File

@ -17,7 +17,6 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.hamcrest.Matchers;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
@ -201,31 +200,43 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
assertThat( assertThat(
new Classification("foo").getParams(fieldInfo), new Classification("foo").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf( equalTo(
hasEntry("dependent_variable", "foo"), org.elasticsearch.common.collect.Map.of(
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), "dependent_variable", "foo",
hasEntry("num_top_classes", 2), "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
hasEntry("prediction_field_name", "foo_prediction"), "num_top_classes", 2,
hasEntry("prediction_field_type", "bool"), "prediction_field_name", "foo_prediction",
hasEntry("num_classes", 10L))); "prediction_field_type", "bool",
"num_classes", 10L,
"training_percent", 100.0)));
assertThat( assertThat(
new Classification("bar").getParams(fieldInfo), new Classification("bar").getParams(fieldInfo),
Matchers.<Map<String, Object>>allOf( equalTo(
hasEntry("dependent_variable", "bar"), org.elasticsearch.common.collect.Map.of(
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), "dependent_variable", "bar",
hasEntry("num_top_classes", 2), "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
hasEntry("prediction_field_name", "bar_prediction"), "num_top_classes", 2,
hasEntry("prediction_field_type", "int"), "prediction_field_name", "bar_prediction",
hasEntry("num_classes", 20L))); "prediction_field_type", "int",
"num_classes", 20L,
"training_percent", 100.0)));
assertThat( assertThat(
new Classification("baz").getParams(fieldInfo), new Classification("baz",
Matchers.<Map<String, Object>>allOf( BoostedTreeParams.builder().build() ,
hasEntry("dependent_variable", "baz"), null,
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), null,
hasEntry("num_top_classes", 2), null,
hasEntry("prediction_field_name", "baz_prediction"), 50.0,
hasEntry("prediction_field_type", "string"), null).getParams(fieldInfo),
hasEntry("num_classes", 30L))); equalTo(
org.elasticsearch.common.collect.Map.of(
"dependent_variable", "baz",
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
"num_top_classes", 2,
"prediction_field_name", "baz_prediction",
"prediction_field_type", "string",
"num_classes", 30L,
"training_percent", 50.0)));
} }
public void testRequiredFieldsIsNonEmpty() { public void testRequiredFieldsIsNonEmpty() {

View File

@ -19,7 +19,6 @@ import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.Collections; import java.util.Collections;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -132,7 +131,20 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
public void testGetParams() { public void testGetParams() {
assertThat( assertThat(
new Regression("foo").getParams(null), new Regression("foo").getParams(null),
allOf(hasEntry("dependent_variable", "foo"), hasEntry("prediction_field_name", "foo_prediction"))); equalTo(org.elasticsearch.common.collect.Map.of(
"dependent_variable", "foo",
"prediction_field_name", "foo_prediction",
"training_percent", 100.0)));
assertThat(
new Regression("foo",
BoostedTreeParams.builder().build(),
null,
50.0,
null).getParams(null),
equalTo(org.elasticsearch.common.collect.Map.of(
"dependent_variable", "foo",
"prediction_field_name", "foo_prediction",
"training_percent", 50.0)));
} }
public void testRequiredFieldsIsNonEmpty() { public void testRequiredFieldsIsNonEmpty() {

View File

@ -10,17 +10,21 @@ import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import java.io.IOException; import java.io.IOException;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase { public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@ -81,4 +85,41 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk().getKb(), lessThanOrEqualTo(1024L)); assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk().getKb(), lessThanOrEqualTo(1024L));
} }
public void testTrainingPercentageIsApplied() throws IOException {
String sourceIndex = "test-training-percentage-applied";
RegressionIT.indexData(sourceIndex, 100, 0);
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("dfa-training-100-" + sourceIndex)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()),
null))
.setAnalysis(new Regression(RegressionIT.DEPENDENT_VARIABLE_FIELD,
BoostedTreeParams.builder().build(),
null,
100.0,
null))
.buildForExplain();
ExplainDataFrameAnalyticsAction.Response explainResponse = explainDataFrame(config);
ByteSizeValue allDataUsedForTraining = explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk();
config = new DataFrameAnalyticsConfig.Builder()
.setId("dfa-training-50-" + sourceIndex)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()),
null))
.setAnalysis(new Regression(RegressionIT.DEPENDENT_VARIABLE_FIELD,
BoostedTreeParams.builder().build(),
null,
50.0,
null))
.buildForExplain();
explainResponse = explainDataFrame(config);
assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk(), lessThan(allDataUsedForTraining));
}
} }

View File

@ -23,7 +23,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.junit.After; import org.junit.After;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -40,10 +39,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FEATURE_FIELD = "feature"; private static final String NUMERICAL_FEATURE_FIELD = "feature";
private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature"; private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
private static final String DEPENDENT_VARIABLE_FIELD = "variable"; static final String DEPENDENT_VARIABLE_FIELD = "variable";
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0)); private static final List<Double> NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(1.0, 2.0, 3.0);
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L)); private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(10L, 20L, 30L);
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0)); private static final List<Double> DEPENDENT_VARIABLE_VALUES = org.elasticsearch.common.collect.List.of(10.0, 20.0, 30.0);
private String jobId; private String jobId;
private String sourceIndex; private String sourceIndex;
@ -399,7 +398,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
this.destIndex = sourceIndex + "_results"; this.destIndex = sourceIndex + "_results";
} }
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) { static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
client().admin().indices().prepareCreate(sourceIndex) client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", .addMapping("_doc",
NUMERICAL_FEATURE_FIELD, "type=double", NUMERICAL_FEATURE_FIELD, "type=double",