mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-17 02:14:54 +00:00
[ML] Account for per-partition categorization in model memory estimate (#59458)
Now that we have per-partition categorization, the estimate for the model memory limit required for a particular analysis config needs to take into account whether categorization is operating for the job as a whole or per-partition.
This commit is contained in:
parent
4350add12c
commit
529aa345df
@ -57,7 +57,7 @@ public class TransportEstimateModelMemoryAction
|
||||
long answer = BASIC_REQUIREMENT.getBytes();
|
||||
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateDetectorsRequirementBytes(analysisConfig, overallCardinality));
|
||||
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateInfluencerRequirementBytes(analysisConfig, maxBucketCardinality));
|
||||
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateCategorizationRequirementBytes(analysisConfig));
|
||||
answer = addNonNegativeLongsWithMaxValueCap(answer, calculateCategorizationRequirementBytes(analysisConfig, overallCardinality));
|
||||
|
||||
listener.onResponse(new EstimateModelMemoryAction.Response(roundUpToNextMb(answer)));
|
||||
}
|
||||
@ -194,14 +194,29 @@ public class TransportEstimateModelMemoryAction
|
||||
return multiplyNonNegativeLongsWithMaxValueCap(BYTES_PER_INFLUENCER_VALUE, totalInfluencerCardinality);
|
||||
}
|
||||
|
||||
static long calculateCategorizationRequirementBytes(AnalysisConfig analysisConfig) {
|
||||
static long calculateCategorizationRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> overallCardinality) {
|
||||
|
||||
if (analysisConfig.getCategorizationFieldName() == null) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
long relevantPartitionFieldCardinalityEstimate = 1;
|
||||
if (analysisConfig.getPerPartitionCategorizationConfig().isEnabled()) {
|
||||
// It is enforced that only one partition field name be configured when per-partition categorization
|
||||
// is enabled, so we can stop after finding a non-null partition field name
|
||||
for (Detector detector : analysisConfig.getDetectors()) {
|
||||
String partitionFieldName = detector.getPartitionFieldName();
|
||||
if (partitionFieldName != null) {
|
||||
relevantPartitionFieldCardinalityEstimate = Math.max(1, cardinalityEstimate(
|
||||
Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName, overallCardinality, true));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5MB is a pretty conservative estimate of the memory requirement for categorization.
|
||||
// Often it is considerably less, but it's very hard to predict from simple statistics.
|
||||
return new ByteSizeValue(5, ByteSizeUnit.MB).getBytes();
|
||||
return new ByteSizeValue(5 * relevantPartitionFieldCardinalityEstimate, ByteSizeUnit.MB).getBytes();
|
||||
}
|
||||
|
||||
static long cardinalityEstimate(String description, String fieldName, Map<String, Long> suppliedCardinailityEstimates,
|
||||
|
@ -10,6 +10,7 @@ import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.Detector;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
@ -77,16 +78,38 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
|
||||
maxBucketCardinality), is((200 + 300) * TransportEstimateModelMemoryAction.BYTES_PER_INFLUENCER_VALUE));
|
||||
}
|
||||
|
||||
public void testCalculateCategorizationRequirementBytes() {
|
||||
public void testCalculateCategorizationRequirementBytesNoCategorization() {
|
||||
|
||||
AnalysisConfig analysisConfigWithoutCategorization = createCountAnalysisConfig(null, null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfigWithoutCategorization), is(0L));
|
||||
Map<String, Long> overallCardinality = new HashMap<>();
|
||||
overallCardinality.put("part", randomLongBetween(10, 1000));
|
||||
|
||||
AnalysisConfig analysisConfigWithCategorization = createCountAnalysisConfig(randomAlphaOfLength(10), null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfigWithCategorization),
|
||||
AnalysisConfig analysisConfig = createCountAnalysisConfig(null, randomBoolean() ? "part" : null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfig, overallCardinality), is(0L));
|
||||
}
|
||||
|
||||
public void testCalculateCategorizationRequirementBytesSimpleCategorization() {
|
||||
|
||||
Map<String, Long> overallCardinality = new HashMap<>();
|
||||
overallCardinality.put("part", randomLongBetween(10, 1000));
|
||||
|
||||
AnalysisConfig analysisConfig =
|
||||
createCountAnalysisConfig(randomAlphaOfLength(10), randomBoolean() ? "part" : null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfig, overallCardinality),
|
||||
is(5L * 1024 * 1024));
|
||||
}
|
||||
|
||||
public void testCalculateCategorizationRequirementBytesPerPartitionCategorization() {
|
||||
|
||||
long partitionCardinality = randomLongBetween(10, 1000);
|
||||
Map<String, Long> overallCardinality = new HashMap<>();
|
||||
overallCardinality.put("part", partitionCardinality);
|
||||
|
||||
AnalysisConfig analysisConfig = createCountAnalysisConfigBuilder(randomAlphaOfLength(10), "part")
|
||||
.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean())).build();
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfig, overallCardinality),
|
||||
is(partitionCardinality * 5L * 1024 * 1024));
|
||||
}
|
||||
|
||||
public void testRoundUpToNextMb() {
|
||||
|
||||
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(0),
|
||||
@ -168,9 +191,15 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
|
||||
|
||||
public static AnalysisConfig createCountAnalysisConfig(String categorizationFieldName, String partitionFieldName,
|
||||
String... influencerFieldNames) {
|
||||
return createCountAnalysisConfigBuilder(categorizationFieldName, partitionFieldName, influencerFieldNames).build();
|
||||
}
|
||||
|
||||
public static AnalysisConfig.Builder createCountAnalysisConfigBuilder(String categorizationFieldName, String partitionFieldName,
|
||||
String... influencerFieldNames) {
|
||||
|
||||
Detector.Builder detectorBuilder = new Detector.Builder("count", null);
|
||||
detectorBuilder.setPartitionFieldName((categorizationFieldName != null) ? AnalysisConfig.ML_CATEGORY_FIELD : partitionFieldName);
|
||||
detectorBuilder.setByFieldName((categorizationFieldName != null) ? AnalysisConfig.ML_CATEGORY_FIELD : null);
|
||||
detectorBuilder.setPartitionFieldName(partitionFieldName);
|
||||
|
||||
AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(detectorBuilder.build()));
|
||||
|
||||
@ -182,6 +211,6 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
|
||||
builder.setInfluencers(Arrays.asList(influencerFieldNames));
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user