[ML] Skeleton estimate_model_memory endpoint for anomaly detection (#53386)
This is a partial implementation of an endpoint for anomaly detector model memory estimation. It is not complete, lacking docs, HLRC and sensible numbers for many anomaly detector configurations. These will be added in a followup PR in time for 7.7 feature freeze. A skeleton endpoint is useful now because it allows work on the UI side of the change to commence. The skeleton endpoint handles the same cases that the old UI code used to handle, and produces very similar estimates for these cases. Backport of #53333
This commit is contained in:
parent
ac721938c2
commit
532a720e1b
|
@ -0,0 +1,176 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.action.ActionRequest;
|
||||
import org.elasticsearch.action.ActionRequestBuilder;
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
import org.elasticsearch.client.ElasticsearchClient;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.Job;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class EstimateModelMemoryAction extends ActionType<EstimateModelMemoryAction.Response> {
|
||||
|
||||
public static final EstimateModelMemoryAction INSTANCE = new EstimateModelMemoryAction();
|
||||
public static final String NAME = "cluster:admin/xpack/ml/job/estimate_model_memory";
|
||||
|
||||
private EstimateModelMemoryAction() {
|
||||
super(NAME, Response::new);
|
||||
}
|
||||
|
||||
public static class Request extends ActionRequest {
|
||||
|
||||
public static final ParseField ANALYSIS_CONFIG = Job.ANALYSIS_CONFIG;
|
||||
public static final ParseField OVERALL_CARDINALITY = new ParseField("overall_cardinality");
|
||||
public static final ParseField MAX_BUCKET_CARDINALITY = new ParseField("max_bucket_cardinality");
|
||||
|
||||
public static final ObjectParser<Request, Void> PARSER =
|
||||
new ObjectParser<>(NAME, EstimateModelMemoryAction.Request::new);
|
||||
|
||||
static {
|
||||
PARSER.declareObject(Request::setAnalysisConfig, (p, c) -> AnalysisConfig.STRICT_PARSER.apply(p, c).build(), ANALYSIS_CONFIG);
|
||||
PARSER.declareObject(Request::setOverallCardinality,
|
||||
(p, c) -> p.map(HashMap::new, parser -> Request.parseNonNegativeLong(parser, OVERALL_CARDINALITY)),
|
||||
OVERALL_CARDINALITY);
|
||||
PARSER.declareObject(Request::setMaxBucketCardinality,
|
||||
(p, c) -> p.map(HashMap::new, parser -> Request.parseNonNegativeLong(parser, MAX_BUCKET_CARDINALITY)),
|
||||
MAX_BUCKET_CARDINALITY);
|
||||
}
|
||||
|
||||
public static Request parseRequest(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private AnalysisConfig analysisConfig;
|
||||
private Map<String, Long> overallCardinality = Collections.emptyMap();
|
||||
private Map<String, Long> maxBucketCardinality = Collections.emptyMap();
|
||||
|
||||
public Request() {
|
||||
super();
|
||||
}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.analysisConfig = in.readBoolean() ? new AnalysisConfig(in) : null;
|
||||
this.overallCardinality = in.readMap(StreamInput::readString, StreamInput::readVLong);
|
||||
this.maxBucketCardinality = in.readMap(StreamInput::readString, StreamInput::readVLong);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
if (analysisConfig != null) {
|
||||
out.writeBoolean(true);
|
||||
analysisConfig.writeTo(out);
|
||||
} else {
|
||||
out.writeBoolean(false);
|
||||
}
|
||||
out.writeMap(overallCardinality, StreamOutput::writeString, StreamOutput::writeVLong);
|
||||
out.writeMap(maxBucketCardinality, StreamOutput::writeString, StreamOutput::writeVLong);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActionRequestValidationException validate() {
|
||||
if (analysisConfig == null) {
|
||||
ActionRequestValidationException e = new ActionRequestValidationException();
|
||||
e.addValidationError("[" + ANALYSIS_CONFIG.getPreferredName() + "] was not specified");
|
||||
return e;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public AnalysisConfig getAnalysisConfig() {
|
||||
return analysisConfig;
|
||||
}
|
||||
|
||||
public void setAnalysisConfig(AnalysisConfig analysisConfig) {
|
||||
this.analysisConfig = ExceptionsHelper.requireNonNull(analysisConfig, ANALYSIS_CONFIG);
|
||||
}
|
||||
|
||||
public Map<String, Long> getOverallCardinality() {
|
||||
return overallCardinality;
|
||||
}
|
||||
|
||||
public void setOverallCardinality(Map<String, Long> overallCardinality) {
|
||||
this.overallCardinality =
|
||||
Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(overallCardinality, OVERALL_CARDINALITY));
|
||||
}
|
||||
|
||||
public Map<String, Long> getMaxBucketCardinality() {
|
||||
return maxBucketCardinality;
|
||||
}
|
||||
|
||||
public void setMaxBucketCardinality(Map<String, Long> maxBucketCardinality) {
|
||||
this.maxBucketCardinality =
|
||||
Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(maxBucketCardinality, MAX_BUCKET_CARDINALITY));
|
||||
}
|
||||
|
||||
private static long parseNonNegativeLong(XContentParser parser, ParseField enclosingField) throws IOException {
|
||||
long value = parser.longValue();
|
||||
if (value < 0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] contained negative cardinality [{}]",
|
||||
enclosingField.getPreferredName(), value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
|
||||
|
||||
public RequestBuilder(ElasticsearchClient client, EstimateModelMemoryAction action) {
|
||||
super(client, action, new Request());
|
||||
}
|
||||
}
|
||||
|
||||
public static class Response extends ActionResponse implements ToXContentObject {
|
||||
|
||||
private static final ParseField MODEL_MEMORY_ESTIMATE = new ParseField("model_memory_estimate");
|
||||
|
||||
private final ByteSizeValue modelMemoryEstimate;
|
||||
|
||||
public Response(ByteSizeValue modelMemoryEstimate) {
|
||||
this.modelMemoryEstimate = Objects.requireNonNull(modelMemoryEstimate);
|
||||
}
|
||||
|
||||
public Response(StreamInput in) throws IOException {
|
||||
modelMemoryEstimate = new ByteSizeValue(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
modelMemoryEstimate.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(MODEL_MEMORY_ESTIMATE.getPreferredName(), modelMemoryEstimate.getStringRep());
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public ByteSizeValue getModelMemoryEstimate() {
|
||||
return modelMemoryEstimate;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -72,6 +72,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
|
||||
|
@ -143,6 +144,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction;
|
|||
import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportEstimateModelMemoryAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportExplainDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction;
|
||||
|
@ -278,6 +280,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction;
|
|||
import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestEstimateModelMemoryAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestFlushJobAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestForecastJobAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestGetJobStatsAction;
|
||||
|
@ -745,6 +748,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
|
|||
new RestFlushJobAction(),
|
||||
new RestValidateDetectorAction(),
|
||||
new RestValidateJobConfigAction(),
|
||||
new RestEstimateModelMemoryAction(),
|
||||
new RestGetCategoriesAction(),
|
||||
new RestGetModelSnapshotsAction(),
|
||||
new RestRevertModelSnapshotAction(),
|
||||
|
@ -819,6 +823,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
|
|||
new ActionHandler<>(FlushJobAction.INSTANCE, TransportFlushJobAction.class),
|
||||
new ActionHandler<>(ValidateDetectorAction.INSTANCE, TransportValidateDetectorAction.class),
|
||||
new ActionHandler<>(ValidateJobConfigAction.INSTANCE, TransportValidateJobConfigAction.class),
|
||||
new ActionHandler<>(EstimateModelMemoryAction.INSTANCE, TransportEstimateModelMemoryAction.class),
|
||||
new ActionHandler<>(GetCategoriesAction.INSTANCE, TransportGetCategoriesAction.class),
|
||||
new ActionHandler<>(GetModelSnapshotsAction.INSTANCE, TransportGetModelSnapshotsAction.class),
|
||||
new ActionHandler<>(RevertModelSnapshotAction.INSTANCE, TransportRevertModelSnapshotAction.class),
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.HandledTransportAction;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.Detector;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class TransportEstimateModelMemoryAction
|
||||
extends HandledTransportAction<EstimateModelMemoryAction.Request, EstimateModelMemoryAction.Response> {
|
||||
|
||||
private static final Logger logger = LogManager.getLogger(TransportEstimateModelMemoryAction.class);
|
||||
|
||||
static final ByteSizeValue BASIC_REQUIREMENT = new ByteSizeValue(10, ByteSizeUnit.MB);
|
||||
static final long BYTES_PER_INFLUENCER_VALUE = new ByteSizeValue(10, ByteSizeUnit.KB).getBytes();
|
||||
private static final long BYTES_IN_MB = new ByteSizeValue(1, ByteSizeUnit.MB).getBytes();
|
||||
|
||||
@Inject
|
||||
public TransportEstimateModelMemoryAction(TransportService transportService,
|
||||
ActionFilters actionFilters) {
|
||||
super(EstimateModelMemoryAction.NAME, transportService, actionFilters, EstimateModelMemoryAction.Request::new);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task,
|
||||
EstimateModelMemoryAction.Request request,
|
||||
ActionListener<EstimateModelMemoryAction.Response> listener) {
|
||||
|
||||
AnalysisConfig analysisConfig = request.getAnalysisConfig();
|
||||
Map<String, Long> overallCardinality = request.getOverallCardinality();
|
||||
Map<String, Long> maxBucketCardinality = request.getMaxBucketCardinality();
|
||||
|
||||
long answer = BASIC_REQUIREMENT.getBytes()
|
||||
+ calculateDetectorsRequirementBytes(analysisConfig, overallCardinality)
|
||||
+ calculateInfluencerRequirementBytes(analysisConfig, maxBucketCardinality)
|
||||
+ calculateCategorizationRequirementBytes(analysisConfig);
|
||||
|
||||
listener.onResponse(new EstimateModelMemoryAction.Response(roundUpToNextMb(answer)));
|
||||
}
|
||||
|
||||
static long calculateDetectorsRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> overallCardinality) {
|
||||
return analysisConfig.getDetectors().stream().map(detector -> calculateDetectorRequirementBytes(detector, overallCardinality))
|
||||
.reduce(0L, Long::sum);
|
||||
}
|
||||
|
||||
static long calculateDetectorRequirementBytes(Detector detector, Map<String, Long> overallCardinality) {
|
||||
|
||||
long answer = 0;
|
||||
|
||||
switch (detector.getFunction()) {
|
||||
case COUNT:
|
||||
case LOW_COUNT:
|
||||
case HIGH_COUNT:
|
||||
case NON_ZERO_COUNT:
|
||||
case LOW_NON_ZERO_COUNT:
|
||||
case HIGH_NON_ZERO_COUNT:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case DISTINCT_COUNT:
|
||||
case LOW_DISTINCT_COUNT:
|
||||
case HIGH_DISTINCT_COUNT:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case RARE:
|
||||
case FREQ_RARE:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case INFO_CONTENT:
|
||||
case LOW_INFO_CONTENT:
|
||||
case HIGH_INFO_CONTENT:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case METRIC:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case MEAN:
|
||||
case LOW_MEAN:
|
||||
case HIGH_MEAN:
|
||||
case AVG:
|
||||
case LOW_AVG:
|
||||
case HIGH_AVG:
|
||||
case MIN:
|
||||
case MAX:
|
||||
case SUM:
|
||||
case LOW_SUM:
|
||||
case HIGH_SUM:
|
||||
case NON_NULL_SUM:
|
||||
case LOW_NON_NULL_SUM:
|
||||
case HIGH_NON_NULL_SUM:
|
||||
// 64 comes from https://github.com/elastic/kibana/issues/18722
|
||||
answer = new ByteSizeValue(64, ByteSizeUnit.KB).getBytes();
|
||||
break;
|
||||
case MEDIAN:
|
||||
case LOW_MEDIAN:
|
||||
case HIGH_MEDIAN:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case VARP:
|
||||
case LOW_VARP:
|
||||
case HIGH_VARP:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case TIME_OF_DAY:
|
||||
case TIME_OF_WEEK:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
case LAT_LONG:
|
||||
answer = 1; // TODO add realistic number
|
||||
break;
|
||||
default:
|
||||
assert false : "unhandled detector function: " + detector.getFunction().getFullName();
|
||||
}
|
||||
|
||||
String byFieldName = detector.getByFieldName();
|
||||
if (byFieldName != null) {
|
||||
answer *= cardinalityEstimate(Detector.BY_FIELD_NAME_FIELD.getPreferredName(), byFieldName, overallCardinality, true);
|
||||
}
|
||||
|
||||
String overFieldName = detector.getOverFieldName();
|
||||
if (overFieldName != null) {
|
||||
cardinalityEstimate(Detector.OVER_FIELD_NAME_FIELD.getPreferredName(), overFieldName, overallCardinality, true);
|
||||
// TODO - how should "over" field cardinality affect estimate?
|
||||
}
|
||||
|
||||
String partitionFieldName = detector.getPartitionFieldName();
|
||||
if (partitionFieldName != null) {
|
||||
answer *=
|
||||
cardinalityEstimate(Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName, overallCardinality, true);
|
||||
}
|
||||
|
||||
return answer;
|
||||
}
|
||||
|
||||
static long calculateInfluencerRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> maxBucketCardinality) {
|
||||
|
||||
// Influencers that are also by/over/partition fields do not consume extra memory by being influencers
|
||||
Set<String> pureInfluencers = new HashSet<>(analysisConfig.getInfluencers());
|
||||
for (Detector detector : analysisConfig.getDetectors()) {
|
||||
pureInfluencers.removeAll(detector.extractAnalysisFields());
|
||||
}
|
||||
|
||||
return pureInfluencers.stream()
|
||||
.map(influencer -> cardinalityEstimate(AnalysisConfig.INFLUENCERS.getPreferredName(), influencer, maxBucketCardinality, false)
|
||||
* BYTES_PER_INFLUENCER_VALUE)
|
||||
.reduce(0L, Long::sum);
|
||||
}
|
||||
|
||||
static long calculateCategorizationRequirementBytes(AnalysisConfig analysisConfig) {
|
||||
|
||||
if (analysisConfig.getCategorizationFieldName() != null) {
|
||||
return 1; // TODO add realistic number
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static long cardinalityEstimate(String description, String fieldName, Map<String, Long> suppliedCardinailityEstimates,
|
||||
boolean isOverall) {
|
||||
Long suppliedEstimate = suppliedCardinailityEstimates.get(fieldName);
|
||||
if (suppliedEstimate != null) {
|
||||
return suppliedEstimate;
|
||||
}
|
||||
// Don't expect the user to supply cardinality estimates for the mlcategory field that we create ourselves
|
||||
if (AnalysisConfig.ML_CATEGORY_FIELD.equals(fieldName)) {
|
||||
return isOverall ? 500 : 50;
|
||||
}
|
||||
logger.warn("[{}] cardinality estimate required for [{}] [{}] but not supplied",
|
||||
isOverall ? "Overall" : "Bucket max", description, fieldName);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static ByteSizeValue roundUpToNextMb(long bytes) {
|
||||
assert bytes >= 0;
|
||||
return new ByteSizeValue((BYTES_IN_MB - 1 + bytes) / BYTES_IN_MB, ByteSizeUnit.MB);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.rest.job;
|
||||
|
||||
import org.elasticsearch.client.node.NodeClient;
|
||||
import org.elasticsearch.rest.BaseRestHandler;
|
||||
import org.elasticsearch.rest.RestRequest;
|
||||
import org.elasticsearch.rest.action.RestToXContentListener;
|
||||
import org.elasticsearch.xpack.core.ml.action.EstimateModelMemoryAction;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.rest.RestRequest.Method.POST;
|
||||
|
||||
public class RestEstimateModelMemoryAction extends BaseRestHandler {
|
||||
|
||||
@Override
|
||||
public List<Route> routes() {
|
||||
return Collections.singletonList(new Route(POST, MachineLearning.BASE_PATH + "anomaly_detectors/_estimate_model_memory"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "ml_estimate_model_memory_action";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
|
||||
EstimateModelMemoryAction.Request request =
|
||||
EstimateModelMemoryAction.Request.parseRequest(restRequest.contentOrSourceParamParser());
|
||||
return channel -> client.execute(EstimateModelMemoryAction.INSTANCE, request, new RestToXContentListener<>(channel));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.Detector;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class TransportEstimateModelMemoryActionTests extends ESTestCase {
|
||||
|
||||
public void testCalculateDetectorRequirementBytes() {
|
||||
|
||||
Map<String, Long> overallCardinality = new HashMap<>();
|
||||
overallCardinality.put("part", 100L);
|
||||
overallCardinality.put("buy", 200L);
|
||||
overallCardinality.put("ovr", 300L);
|
||||
|
||||
String function = randomFrom("mean", "min", "max", "sum");
|
||||
|
||||
Detector noSplit = createDetector(function, "field", null, null, null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(noSplit,
|
||||
overallCardinality), is(65536L));
|
||||
|
||||
Detector withByField = createDetector(function, "field", "buy", null, null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(withByField,
|
||||
overallCardinality), is(200 * 65536L));
|
||||
|
||||
Detector withPartitionField = createDetector(function, "field", null, null, "part");
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(withPartitionField,
|
||||
overallCardinality), is(100 * 65536L));
|
||||
|
||||
Detector withByAndPartitionFields = createDetector(function, "field", "buy", null, "part");
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(withByAndPartitionFields,
|
||||
overallCardinality), is(200 * 100 * 65536L));
|
||||
}
|
||||
|
||||
public void testCalculateInfluencerRequirementBytes() {
|
||||
|
||||
Map<String, Long> maxBucketCardinality = new HashMap<>();
|
||||
maxBucketCardinality.put("part", 100L);
|
||||
maxBucketCardinality.put("inf1", 200L);
|
||||
maxBucketCardinality.put("inf2", 300L);
|
||||
|
||||
AnalysisConfig noInfluencers = createCountAnalysisConfig(null, null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateInfluencerRequirementBytes(noInfluencers,
|
||||
maxBucketCardinality), is(0L));
|
||||
|
||||
AnalysisConfig influencerAlsoPartitionField = createCountAnalysisConfig(null, "part", "part");
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateInfluencerRequirementBytes(influencerAlsoPartitionField,
|
||||
maxBucketCardinality), is(0L));
|
||||
|
||||
AnalysisConfig influencerNotPartitionField = createCountAnalysisConfig(null, "part", "inf1");
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateInfluencerRequirementBytes(influencerNotPartitionField,
|
||||
maxBucketCardinality), is(200 * TransportEstimateModelMemoryAction.BYTES_PER_INFLUENCER_VALUE));
|
||||
|
||||
AnalysisConfig otherInfluencerAsWellAsPartitionField = createCountAnalysisConfig(null, "part", "part", "inf1");
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateInfluencerRequirementBytes(otherInfluencerAsWellAsPartitionField,
|
||||
maxBucketCardinality), is(200 * TransportEstimateModelMemoryAction.BYTES_PER_INFLUENCER_VALUE));
|
||||
|
||||
AnalysisConfig twoInfluencersNotPartitionField = createCountAnalysisConfig(null, "part", "part", "inf1", "inf2");
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateInfluencerRequirementBytes(twoInfluencersNotPartitionField,
|
||||
maxBucketCardinality), is((200 + 300) * TransportEstimateModelMemoryAction.BYTES_PER_INFLUENCER_VALUE));
|
||||
}
|
||||
|
||||
public void testCalculateCategorizationRequirementBytes() {
|
||||
|
||||
AnalysisConfig analysisConfigWithoutCategorization = createCountAnalysisConfig(null, null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfigWithoutCategorization), is(0L));
|
||||
|
||||
AnalysisConfig analysisConfigWithCategorization = createCountAnalysisConfig(randomAlphaOfLength(10), null);
|
||||
assertThat(TransportEstimateModelMemoryAction.calculateCategorizationRequirementBytes(analysisConfigWithCategorization), is(1L));
|
||||
}
|
||||
|
||||
public void testRoundUpToNextMb() {
|
||||
|
||||
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(0),
|
||||
equalTo(new ByteSizeValue(0, ByteSizeUnit.BYTES)));
|
||||
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(1),
|
||||
equalTo(new ByteSizeValue(1, ByteSizeUnit.MB)));
|
||||
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(randomIntBetween(1, 1024 * 1024)),
|
||||
equalTo(new ByteSizeValue(1, ByteSizeUnit.MB)));
|
||||
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(1024 * 1024),
|
||||
equalTo(new ByteSizeValue(1, ByteSizeUnit.MB)));
|
||||
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(1024 * 1024 + 1),
|
||||
equalTo(new ByteSizeValue(2, ByteSizeUnit.MB)));
|
||||
assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(2 * 1024 * 1024),
|
||||
equalTo(new ByteSizeValue(2, ByteSizeUnit.MB)));
|
||||
}
|
||||
|
||||
public static Detector createDetector(String function, String fieldName, String byFieldName,
|
||||
String overFieldName, String partitionFieldName) {
|
||||
|
||||
Detector.Builder detectorBuilder = new Detector.Builder(function, fieldName);
|
||||
detectorBuilder.setByFieldName(byFieldName);
|
||||
detectorBuilder.setOverFieldName(overFieldName);
|
||||
detectorBuilder.setPartitionFieldName(partitionFieldName);
|
||||
return detectorBuilder.build();
|
||||
}
|
||||
|
||||
public static AnalysisConfig createCountAnalysisConfig(String categorizationFieldName, String partitionFieldName,
|
||||
String... influencerFieldNames) {
|
||||
|
||||
Detector.Builder detectorBuilder = new Detector.Builder("count", null);
|
||||
detectorBuilder.setPartitionFieldName((categorizationFieldName != null) ? AnalysisConfig.ML_CATEGORY_FIELD : partitionFieldName);
|
||||
|
||||
AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(detectorBuilder.build()));
|
||||
|
||||
if (categorizationFieldName != null) {
|
||||
builder.setCategorizationFieldName(categorizationFieldName);
|
||||
}
|
||||
|
||||
if (influencerFieldNames.length > 0) {
|
||||
builder.setInfluencers(Arrays.asList(influencerFieldNames));
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"ml.estimate_model_memory":{
|
||||
"documentation":{
|
||||
"url":null
|
||||
},
|
||||
"stability":"stable",
|
||||
"url":{
|
||||
"paths":[
|
||||
{
|
||||
"path":"/_ml/anomaly_detectors/_estimate_model_memory",
|
||||
"methods":[
|
||||
"POST"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"params":{},
|
||||
"body":{
|
||||
"description":"The analysis config, plus cardinality estimates for fields it references",
|
||||
"required":true
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
---
|
||||
"Test by field":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "by_field_name": "airline"}]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 50000
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "3135mb" }
|
||||
|
||||
---
|
||||
"Test by field also influencer":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "by_field_name": "airline"}],
|
||||
"influencers": [ "airline" ]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 50000
|
||||
},
|
||||
"max_bucket_cardinality": {
|
||||
"airline": 500
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "3135mb" }
|
||||
|
||||
---
|
||||
"Test by field with independent influencer":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "by_field_name": "airline"}],
|
||||
"influencers": [ "country" ]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 50000
|
||||
},
|
||||
"max_bucket_cardinality": {
|
||||
"country": 500
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "3140mb" }
|
||||
|
||||
---
|
||||
"Test partition field":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "partition_field_name": "airline"}]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 50000
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "3135mb" }
|
||||
|
||||
---
|
||||
"Test partition field also influencer":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "partition_field_name": "airline"}],
|
||||
"influencers": [ "airline" ]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 50000
|
||||
},
|
||||
"max_bucket_cardinality": {
|
||||
"airline": 500
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "3135mb" }
|
||||
|
||||
---
|
||||
"Test partition field with independent influencer":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "partition_field_name": "airline"}],
|
||||
"influencers": [ "country" ]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 50000
|
||||
},
|
||||
"max_bucket_cardinality": {
|
||||
"country": 500
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "3140mb" }
|
||||
|
||||
---
|
||||
"Test by and partition field":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "by_field_name": "airline", "partition_field_name": "country"}]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 4000,
|
||||
"country": 600
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "150010mb" }
|
||||
|
||||
---
|
||||
"Test by and partition fields also influencers":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "by_field_name": "airline", "partition_field_name": "country"}],
|
||||
"influencers": [ "airline", "country" ]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 4000,
|
||||
"country": 600
|
||||
},
|
||||
"max_bucket_cardinality": {
|
||||
"airline": 60,
|
||||
"country": 40
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "150010mb" }
|
||||
|
||||
---
|
||||
"Test by and partition fields with independent influencer":
|
||||
- do:
|
||||
ml.estimate_model_memory:
|
||||
body: >
|
||||
{
|
||||
"analysis_config": {
|
||||
"bucket_span": "1h",
|
||||
"detectors": [{"function": "max", "field_name": "responsetime", "by_field_name": "airline", "partition_field_name": "country"}],
|
||||
"influencers": [ "src_ip" ]
|
||||
},
|
||||
"overall_cardinality": {
|
||||
"airline": 4000,
|
||||
"country": 600
|
||||
},
|
||||
"max_bucket_cardinality": {
|
||||
"src_ip": 500
|
||||
}
|
||||
}
|
||||
- match: { model_memory_estimate: "150015mb" }
|
||||
|
Loading…
Reference in New Issue