[ML] Account for per-partition categorization in model memory estimate (#59458)

Now that we have per-partition categorization, the estimate for
the model memory limit required for a particular analysis config
needs to take into account whether categorization is operating
for the job as a whole or per-partition.
This commit is contained in:
David Roberts 2020-07-14 09:13:30 +01:00
parent 4350add12c
commit 529aa345df
2 changed files with 54 additions and 10 deletions

View File

@ -57,7 +57,7 @@ public class TransportEstimateModelMemoryAction
long answer = BASIC_REQUIREMENT.getBytes();
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateDetectorsRequirementBytes(analysisConfig, overallCardinality));
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateInfluencerRequirementBytes(analysisConfig, maxBucketCardinality));
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateCategorizationRequirementBytes(analysisConfig));
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateCategorizationRequirementBytes(analysisConfig, overallCardinality));
listener.onResponse(new EstimateModelMemoryAction.Response(roundUpToNextMb(answer)));
}
@ -194,14 +194,29 @@ public class TransportEstimateModelMemoryAction
return multiplyNonNegativeLongsWithMaxValueCap(BYTES_PER_INFLUENCER_VALUE, totalInfluencerCardinality);
}
/**
 * Estimates how many bytes of model memory to allow for categorization.
 * Returns 0 when the analysis config has no categorization field name. Otherwise the result is
 * built on a flat 5MB allowance, which the two-argument form multiplies by the cardinality
 * estimate of the partition field when per-partition categorization is enabled.
 *
 * NOTE(review): this span reads like a rendered diff with the +/- markers stripped - it carries two
 * signature lines and two return statements back to back. Presumably the one-argument signature
 * and the fixed 5MB return are the pre-change side; confirm against the real file.
 */
static long calculateCategorizationRequirementBytes(AnalysisConfig analysisConfig) {
static long calculateCategorizationRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> overallCardinality) {
// No categorization field configured means no categorization memory is needed.
if (analysisConfig.getCategorizationFieldName() == null) {
return 0;
}
// Multiplier of 1 covers categorization that runs once for the whole job.
long relevantPartitionFieldCardinalityEstimate = 1;
if (analysisConfig.getPerPartitionCategorizationConfig().isEnabled()) {
// It is enforced that only one partition field name be configured when per-partition categorization
// is enabled, so we can stop after finding a non-null partition field name
for (Detector detector : analysisConfig.getDetectors()) {
String partitionFieldName = detector.getPartitionFieldName();
if (partitionFieldName != null) {
// Math.max keeps the multiplier at 1 or above whatever cardinalityEstimate returns.
relevantPartitionFieldCardinalityEstimate = Math.max(1, cardinalityEstimate(
Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName, overallCardinality, true));
break;
}
}
}
// 5MB is a pretty conservative estimate of the memory requirement for categorization.
// Often it is considerably less, but it's very hard to predict from simple statistics.
return new ByteSizeValue(5, ByteSizeUnit.MB).getBytes();
// NOTE(review): the product below is handed straight to ByteSizeValue, whereas the influencer
// calculation in this class goes through multiplyNonNegativeLongsWithMaxValueCap. Nothing visible
// here caps a very large cardinality estimate - confirm whether that can overflow or be rejected.
return new ByteSizeValue(5 * relevantPartitionFieldCardinalityEstimate, ByteSizeUnit.MB).getBytes();
}
static long cardinalityEstimate(String description, String fieldName, Map<String, Long> suppliedCardinailityEstimates,

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
import java.util.Arrays;
import java.util.Collections;
@ -77,16 +78,38 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
maxBucketCardinality), is((200 + 300) * TransportEstimateModelMemoryAction.BYTES_PER_INFLUENCER_VALUE));
}
// Checks that the categorization requirement is 0 bytes when the analysis config is created with a
// null categorization field name.
// NOTE(review): this span looks like a rendered diff with the +/- markers stripped. It holds two
// method signatures in a row and an assertThat call that is cut off after its first argument, so
// some of these lines are presumably the pre-change test - confirm against the real file.
public void testCalculateCategorizationRequirementBytes() {
public void testCalculateCategorizationRequirementBytesNoCategorization() {
AnalysisConfig analysisConfigWithoutCategorization = createCountAnalysisConfig(null, null);
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfigWithoutCategorization), is(0L));
// A cardinality is supplied for "part" even though the expected answer below is still zero.
Map<String, Long> overallCardinality = new HashMap<>();
overallCardinality.put("part", randomLongBetween(10, 1000));
AnalysisConfig analysisConfigWithCategorization = createCountAnalysisConfig(randomAlphaOfLength(10), null);
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfigWithCategorization),
// The partition field is randomly "part" or null; the categorization field name is always null here.
AnalysisConfig analysisConfig = createCountAnalysisConfig(null, randomBoolean() ? "part" : null);
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfig, overallCardinality), is(0L));
}
/**
 * With a categorization field set and per-partition categorization left unconfigured, the
 * requirement must be exactly 5MB whether or not the detector has a partition field.
 */
public void testCalculateCategorizationRequirementBytesSimpleCategorization() {
    // A cardinality for "part" is always supplied; the expected value below does not depend on it.
    Map<String, Long> cardinalityByField = new HashMap<>();
    cardinalityByField.put("part", randomLongBetween(10, 1000));

    // Random values are drawn in the same order as before: categorization field first, then the partition choice.
    String categorizationField = randomAlphaOfLength(10);
    String partitionField = randomBoolean() ? "part" : null;
    AnalysisConfig config = createCountAnalysisConfig(categorizationField, partitionField);

    long expectedBytes = 5L * 1024 * 1024;
    long actualBytes = TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(config, cardinalityByField);
    assertThat(actualBytes, is(expectedBytes));
}
/**
 * With per-partition categorization enabled, the requirement must be 5MB multiplied by the
 * cardinality supplied for the partition field.
 */
public void testCalculateCategorizationRequirementBytesPerPartitionCategorization() {
    long partitionCardinality = randomLongBetween(10, 1000);
    Map<String, Long> cardinalityByField = new HashMap<>();
    cardinalityByField.put("part", partitionCardinality);

    // Build the config in steps; the first flag enables per-partition categorization, the second is random.
    AnalysisConfig.Builder configBuilder = createCountAnalysisConfigBuilder(randomAlphaOfLength(10), "part");
    configBuilder.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean()));
    AnalysisConfig config = configBuilder.build();

    long expectedBytes = partitionCardinality * 5L * 1024 * 1024;
    long actualBytes = TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(config, cardinalityByField);
    assertThat(actualBytes, is(expectedBytes));
}
public void testRoundUpToNextMb() {
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(0),
@ -168,9 +191,15 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
/**
 * Convenience wrapper that builds the analysis config produced by
 * {@code createCountAnalysisConfigBuilder} for the same arguments.
 */
public static AnalysisConfig createCountAnalysisConfig(String categorizationFieldName, String partitionFieldName,
                                                       String... influencerFieldNames) {
    AnalysisConfig.Builder configBuilder =
        createCountAnalysisConfigBuilder(categorizationFieldName, partitionFieldName, influencerFieldNames);
    return configBuilder.build();
}
// Creates an AnalysisConfig.Builder around a single "count" detector.
// NOTE(review): this span looks like a rendered diff with the +/- markers stripped, and the hunk
// header part-way down means some lines of the body are not shown. The setPartitionFieldName pair
// and the two return statements are presumably the before/after sides of the change - confirm
// against the real file.
public static AnalysisConfig.Builder createCountAnalysisConfigBuilder(String categorizationFieldName, String partitionFieldName,
String... influencerFieldNames) {
Detector.Builder detectorBuilder = new Detector.Builder("count", null);
// This line puts ML_CATEGORY_FIELD in the partition field whenever a categorization field is given.
detectorBuilder.setPartitionFieldName((categorizationFieldName != null) ? AnalysisConfig.ML_CATEGORY_FIELD : partitionFieldName);
// These two lines put ML_CATEGORY_FIELD in the by field instead and always pass the supplied partition field through.
detectorBuilder.setByFieldName((categorizationFieldName != null) ? AnalysisConfig.ML_CATEGORY_FIELD : null);
detectorBuilder.setPartitionFieldName(partitionFieldName);
AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(detectorBuilder.build()));
@ -182,6 +211,6 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
builder.setInfluencers(Arrays.asList(influencerFieldNames));
}
return builder.build();
// Returns the builder itself so that callers can apply further settings before calling build().
return builder;
}
}