This adds training_percent parameter to the analytics process for Classification and Regression. This parameter is then used to give more accurate memory estimations. See native side pr: elastic/ml-cpp#1111
This commit is contained in:
parent
54ea4f4f50
commit
7fe38935f6
|
@ -241,6 +241,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
|
||||
}
|
||||
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
|
||||
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
return params;
|
||||
}
|
||||
|
||||
|
|
|
@ -163,6 +163,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (predictionFieldName != null) {
|
||||
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
return params;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
|||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.hamcrest.Matchers;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
@ -201,31 +200,43 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
|
||||
assertThat(
|
||||
new Classification("foo").getParams(fieldInfo),
|
||||
Matchers.<Map<String, Object>>allOf(
|
||||
hasEntry("dependent_variable", "foo"),
|
||||
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
|
||||
hasEntry("num_top_classes", 2),
|
||||
hasEntry("prediction_field_name", "foo_prediction"),
|
||||
hasEntry("prediction_field_type", "bool"),
|
||||
hasEntry("num_classes", 10L)));
|
||||
equalTo(
|
||||
org.elasticsearch.common.collect.Map.of(
|
||||
"dependent_variable", "foo",
|
||||
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
|
||||
"num_top_classes", 2,
|
||||
"prediction_field_name", "foo_prediction",
|
||||
"prediction_field_type", "bool",
|
||||
"num_classes", 10L,
|
||||
"training_percent", 100.0)));
|
||||
assertThat(
|
||||
new Classification("bar").getParams(fieldInfo),
|
||||
Matchers.<Map<String, Object>>allOf(
|
||||
hasEntry("dependent_variable", "bar"),
|
||||
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
|
||||
hasEntry("num_top_classes", 2),
|
||||
hasEntry("prediction_field_name", "bar_prediction"),
|
||||
hasEntry("prediction_field_type", "int"),
|
||||
hasEntry("num_classes", 20L)));
|
||||
equalTo(
|
||||
org.elasticsearch.common.collect.Map.of(
|
||||
"dependent_variable", "bar",
|
||||
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
|
||||
"num_top_classes", 2,
|
||||
"prediction_field_name", "bar_prediction",
|
||||
"prediction_field_type", "int",
|
||||
"num_classes", 20L,
|
||||
"training_percent", 100.0)));
|
||||
assertThat(
|
||||
new Classification("baz").getParams(fieldInfo),
|
||||
Matchers.<Map<String, Object>>allOf(
|
||||
hasEntry("dependent_variable", "baz"),
|
||||
hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL),
|
||||
hasEntry("num_top_classes", 2),
|
||||
hasEntry("prediction_field_name", "baz_prediction"),
|
||||
hasEntry("prediction_field_type", "string"),
|
||||
hasEntry("num_classes", 30L)));
|
||||
new Classification("baz",
|
||||
BoostedTreeParams.builder().build() ,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
50.0,
|
||||
null).getParams(fieldInfo),
|
||||
equalTo(
|
||||
org.elasticsearch.common.collect.Map.of(
|
||||
"dependent_variable", "baz",
|
||||
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
|
||||
"num_top_classes", 2,
|
||||
"prediction_field_name", "baz_prediction",
|
||||
"prediction_field_type", "string",
|
||||
"num_classes", 30L,
|
||||
"training_percent", 50.0)));
|
||||
}
|
||||
|
||||
public void testRequiredFieldsIsNonEmpty() {
|
||||
|
|
|
@ -19,7 +19,6 @@ import java.io.IOException;
|
|||
import java.util.Map;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -132,7 +131,20 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
public void testGetParams() {
|
||||
assertThat(
|
||||
new Regression("foo").getParams(null),
|
||||
allOf(hasEntry("dependent_variable", "foo"), hasEntry("prediction_field_name", "foo_prediction")));
|
||||
equalTo(org.elasticsearch.common.collect.Map.of(
|
||||
"dependent_variable", "foo",
|
||||
"prediction_field_name", "foo_prediction",
|
||||
"training_percent", 100.0)));
|
||||
assertThat(
|
||||
new Regression("foo",
|
||||
BoostedTreeParams.builder().build(),
|
||||
null,
|
||||
50.0,
|
||||
null).getParams(null),
|
||||
equalTo(org.elasticsearch.common.collect.Map.of(
|
||||
"dependent_variable", "foo",
|
||||
"prediction_field_name", "foo_prediction",
|
||||
"training_percent", 50.0)));
|
||||
}
|
||||
|
||||
public void testRequiredFieldsIsNonEmpty() {
|
||||
|
|
|
@ -10,17 +10,21 @@ import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
|||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||
|
||||
public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
@ -81,4 +85,41 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
|
|||
|
||||
assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk().getKb(), lessThanOrEqualTo(1024L));
|
||||
}
|
||||
|
||||
public void testTrainingPercentageIsApplied() throws IOException {
|
||||
String sourceIndex = "test-training-percentage-applied";
|
||||
RegressionIT.indexData(sourceIndex, 100, 0);
|
||||
|
||||
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
|
||||
.setId("dfa-training-100-" + sourceIndex)
|
||||
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
|
||||
QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()),
|
||||
null))
|
||||
.setAnalysis(new Regression(RegressionIT.DEPENDENT_VARIABLE_FIELD,
|
||||
BoostedTreeParams.builder().build(),
|
||||
null,
|
||||
100.0,
|
||||
null))
|
||||
.buildForExplain();
|
||||
|
||||
ExplainDataFrameAnalyticsAction.Response explainResponse = explainDataFrame(config);
|
||||
|
||||
ByteSizeValue allDataUsedForTraining = explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk();
|
||||
|
||||
config = new DataFrameAnalyticsConfig.Builder()
|
||||
.setId("dfa-training-50-" + sourceIndex)
|
||||
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
|
||||
QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()),
|
||||
null))
|
||||
.setAnalysis(new Regression(RegressionIT.DEPENDENT_VARIABLE_FIELD,
|
||||
BoostedTreeParams.builder().build(),
|
||||
null,
|
||||
50.0,
|
||||
null))
|
||||
.buildForExplain();
|
||||
|
||||
explainResponse = explainDataFrame(config);
|
||||
|
||||
assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk(), lessThan(allDataUsedForTraining));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
|||
import org.junit.After;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
@ -40,10 +39,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
||||
private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
|
||||
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
||||
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
||||
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L));
|
||||
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
|
||||
static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
||||
private static final List<Double> NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(1.0, 2.0, 3.0);
|
||||
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(10L, 20L, 30L);
|
||||
private static final List<Double> DEPENDENT_VARIABLE_VALUES = org.elasticsearch.common.collect.List.of(10.0, 20.0, 30.0);
|
||||
|
||||
private String jobId;
|
||||
private String sourceIndex;
|
||||
|
@ -399,7 +398,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
this.destIndex = sourceIndex + "_results";
|
||||
}
|
||||
|
||||
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
|
||||
static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
|
||||
client().admin().indices().prepareCreate(sourceIndex)
|
||||
.addMapping("_doc",
|
||||
NUMERICAL_FEATURE_FIELD, "type=double",
|
||||
|
|
Loading…
Reference in New Issue