[ML] add max_model_memory parameter to forecast request (#57254) (#57355)

This adds a max_model_memory setting to forecast requests. 
This setting can take a string value that is formatted according to byte sizes (i.e. "50mb", "150mb").

The default value is `20mb`.

There is a HARD limit at `500mb` which will throw an error if used.

If the limit is larger than 40% the anomaly job's configured model limit, the forecast limit is reduced to be strictly lower than that value. This reduction is logged and audited.

related native change: https://github.com/elastic/ml-cpp/pull/1238

closes: https://github.com/elastic/elasticsearch/issues/56420
This commit is contained in:
Benjamin Trent 2020-05-29 11:16:08 -04:00 committed by GitHub
parent e4fd78f866
commit c8374dc9f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 320 additions and 12 deletions

View File

@ -22,11 +22,15 @@ import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
@ -38,6 +42,7 @@ public class ForecastJobRequest extends ActionRequest implements ToXContentObjec
public static final ParseField DURATION = new ParseField("duration");
public static final ParseField EXPIRES_IN = new ParseField("expires_in");
public static final ParseField MAX_MODEL_MEMORY = new ParseField("max_model_memory");
public static final ConstructingObjectParser<ForecastJobRequest, Void> PARSER =
new ConstructingObjectParser<>("forecast_job_request", (a) -> new ForecastJobRequest((String)a[0]));
@ -48,11 +53,20 @@ public class ForecastJobRequest extends ActionRequest implements ToXContentObjec
(request, val) -> request.setDuration(TimeValue.parseTimeValue(val, DURATION.getPreferredName())), DURATION);
PARSER.declareString(
(request, val) -> request.setExpiresIn(TimeValue.parseTimeValue(val, EXPIRES_IN.getPreferredName())), EXPIRES_IN);
PARSER.declareField(ForecastJobRequest::setMaxModelMemory, (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return ByteSizeValue.parseBytesSizeValue(p.text(), MAX_MODEL_MEMORY.getPreferredName());
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return new ByteSizeValue(p.longValue());
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, MAX_MODEL_MEMORY, ObjectParser.ValueType.VALUE);
}
private final String jobId;
private TimeValue duration;
private TimeValue expiresIn;
private ByteSizeValue maxModelMemory;
/**
* A new forecast request
@ -100,9 +114,25 @@ public class ForecastJobRequest extends ActionRequest implements ToXContentObjec
this.expiresIn = expiresIn;
}
public ByteSizeValue getMaxModelMemory() {
return maxModelMemory;
}
/**
* Set the amount of memory allowed to be used by this forecast.
*
* If the projected forecast memory usage exceeds this amount, the forecast will spool results to disk to keep within the limits.
* @param maxModelMemory A byte sized value less than 500MB and less than 40% of the associated job's configured memory usage.
* Defaults to 20MB.
*/
public ForecastJobRequest setMaxModelMemory(ByteSizeValue maxModelMemory) {
this.maxModelMemory = maxModelMemory;
return this;
}
@Override
public int hashCode() {
return Objects.hash(jobId, duration, expiresIn);
return Objects.hash(jobId, duration, expiresIn, maxModelMemory);
}
@Override
@ -116,7 +146,8 @@ public class ForecastJobRequest extends ActionRequest implements ToXContentObjec
ForecastJobRequest other = (ForecastJobRequest) obj;
return Objects.equals(jobId, other.jobId)
&& Objects.equals(duration, other.duration)
&& Objects.equals(expiresIn, other.expiresIn);
&& Objects.equals(expiresIn, other.expiresIn)
&& Objects.equals(maxModelMemory, other.maxModelMemory);
}
@Override
@ -129,6 +160,9 @@ public class ForecastJobRequest extends ActionRequest implements ToXContentObjec
if (expiresIn != null) {
builder.field(EXPIRES_IN.getPreferredName(), expiresIn.getStringRep());
}
if (maxModelMemory != null) {
builder.field(MAX_MODEL_MEMORY.getPreferredName(), maxModelMemory.getStringRep());
}
builder.endObject();
return builder;
}

View File

@ -1506,6 +1506,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// tag::forecast-job-request-options
forecastJobRequest.setExpiresIn(TimeValue.timeValueHours(48)); // <1>
forecastJobRequest.setDuration(TimeValue.timeValueHours(24)); // <2>
forecastJobRequest.setMaxModelMemory(new ByteSizeValue(30, ByteSizeUnit.MB)); // <3>
// end::forecast-job-request-options
// tag::forecast-job-execute

View File

@ -18,6 +18,8 @@
*/
package org.elasticsearch.client.ml;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
@ -36,6 +38,11 @@ public class ForecastJobRequestTests extends AbstractXContentTestCase<ForecastJo
if (randomBoolean()) {
request.setDuration(TimeValue.timeValueHours(randomIntBetween(24, 72)));
}
if (randomBoolean()) {
request.setMaxModelMemory(new ByteSizeValue(randomLongBetween(
new ByteSizeValue(1, ByteSizeUnit.MB).getBytes(),
new ByteSizeValue(499, ByteSizeUnit.MB).getBytes())));
}
return request;
}

View File

@ -34,6 +34,10 @@ include-tagged::{doc-tests-file}[{api}-request-options]
--------------------------------------------------
<1> Set when the forecast for the job should expire
<2> Set how far into the future should the forecast predict
<3> Set the maximum amount of memory the forecast is allowed to use.
Defaults to 20mb. Maximum is 500mb, minimum is 1mb. If set to
40% or more of the job's configured memory limit, it is
automatically reduced to below that number.
[id="{upid}-{api}-response"]
==== Forecast Job Response

View File

@ -62,6 +62,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=job-id-anomaly-detection]
default value is 14 days. If set to a value of `0`, the forecast is never
automatically deleted.
`max_model_memory`::
(Optional, <<byte-units,byte value>>) The maximum memory the forecast can use.
If the forecast needs to use more than the provided amount, it will spool to
disk. Default is 20mb, maximum is 500mb and minimum is 1mb. If set to 40% or
more of the job's configured memory limit, it is automatically reduced to
below that amount.
[[ml-forecast-example]]
==== {api-examples-title}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
@ -13,13 +14,17 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.results.Forecast;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Objects;
@ -37,9 +42,13 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
public static final ParseField DURATION = new ParseField("duration");
public static final ParseField EXPIRES_IN = new ParseField("expires_in");
public static final ParseField MAX_MODEL_MEMORY = new ParseField("max_model_memory");
public static final ByteSizeValue FORECAST_LOCAL_STORAGE_LIMIT = new ByteSizeValue(500, ByteSizeUnit.MB);
// Max allowed duration: 10 years
private static final TimeValue MAX_DURATION = TimeValue.parseTimeValue("3650d", "");
private static final long MIN_MODEL_MEMORY = new ByteSizeValue(1, ByteSizeUnit.MB).getBytes();
private static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
@ -47,6 +56,14 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
PARSER.declareString((request, jobId) -> request.jobId = jobId, Job.ID);
PARSER.declareString(Request::setDuration, DURATION);
PARSER.declareString(Request::setExpiresIn, EXPIRES_IN);
PARSER.declareField(Request::setMaxModelMemory, (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return ByteSizeValue.parseBytesSizeValue(p.text(), MAX_MODEL_MEMORY.getPreferredName()).getBytes();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.longValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, MAX_MODEL_MEMORY, ObjectParser.ValueType.VALUE);
}
public static Request parseRequest(String jobId, XContentParser parser) {
@ -59,6 +76,7 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
private TimeValue duration;
private TimeValue expiresIn;
private Long maxModelMemory;
public Request() {
}
@ -67,6 +85,9 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
super(in);
this.duration = in.readOptionalTimeValue();
this.expiresIn = in.readOptionalTimeValue();
if (in.getVersion().onOrAfter(Version.V_7_9_0)) {
this.maxModelMemory = in.readOptionalVLong();
}
}
@Override
@ -74,6 +95,9 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
super.writeTo(out);
out.writeOptionalTimeValue(duration);
out.writeOptionalTimeValue(expiresIn);
if (out.getVersion().onOrAfter(Version.V_7_9_0)) {
out.writeOptionalVLong(maxModelMemory);
}
}
public Request(String jobId) {
@ -116,9 +140,26 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
}
}
public void setMaxModelMemory(long numBytes) {
if (numBytes < MIN_MODEL_MEMORY) {
throw new IllegalArgumentException("[" + MAX_MODEL_MEMORY.getPreferredName() + "] must be at least 1mb.");
}
if (numBytes >= FORECAST_LOCAL_STORAGE_LIMIT.getBytes()) {
throw ExceptionsHelper.badRequestException(
"[{}] must be less than {}",
MAX_MODEL_MEMORY.getPreferredName(),
FORECAST_LOCAL_STORAGE_LIMIT.getStringRep());
}
this.maxModelMemory = numBytes;
}
public Long getMaxModelMemory() {
return maxModelMemory;
}
@Override
public int hashCode() {
return Objects.hash(jobId, duration, expiresIn);
return Objects.hash(jobId, duration, expiresIn, maxModelMemory);
}
@Override
@ -132,7 +173,8 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
Request other = (Request) obj;
return Objects.equals(jobId, other.jobId)
&& Objects.equals(duration, other.duration)
&& Objects.equals(expiresIn, other.expiresIn);
&& Objects.equals(expiresIn, other.expiresIn)
&& Objects.equals(maxModelMemory, other.maxModelMemory);
}
@Override
@ -145,6 +187,9 @@ public class ForecastJobAction extends ActionType<ForecastJobAction.Response> {
if (expiresIn != null) {
builder.field(EXPIRES_IN.getPreferredName(), expiresIn.getStringRep());
}
if (maxModelMemory != null) {
builder.field(MAX_MODEL_MEMORY.getPreferredName(), new ByteSizeValue(maxModelMemory).getStringRep());
}
builder.endObject();
return builder;
}

View File

@ -6,6 +6,8 @@
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
@ -34,6 +36,11 @@ public class ForecastJobActionRequestTests extends AbstractSerializingTestCase<R
if (randomBoolean()) {
request.setExpiresIn(TimeValue.timeValueSeconds(randomIntBetween(0, 1_000_000)).getStringRep());
}
if (randomBoolean()) {
request.setMaxModelMemory(randomLongBetween(
new ByteSizeValue(1, ByteSizeUnit.MB).getBytes(),
new ByteSizeValue(499, ByteSizeUnit.MB).getBytes()));
}
return request;
}

View File

@ -10,6 +10,8 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
@ -22,7 +24,6 @@ import org.elasticsearch.xpack.core.ml.job.results.Forecast;
import org.elasticsearch.xpack.core.ml.job.results.ForecastRequestStats;
import org.junit.After;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
@ -379,7 +380,54 @@ public class ForecastIT extends MlNativeAutodetectIntegTestCase {
}
}
private void createDataWithLotsOfClientIps(TimeValue bucketSpan, Job.Builder job) throws IOException {
public void testForecastWithHigherMemoryUse() throws Exception {
Detector.Builder detector = new Detector.Builder("mean", "value");
TimeValue bucketSpan = TimeValue.timeValueHours(1);
AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(Collections.singletonList(detector.build()));
analysisConfig.setBucketSpan(bucketSpan);
DataDescription.Builder dataDescription = new DataDescription.Builder();
dataDescription.setTimeFormat("epoch");
Job.Builder job = new Job.Builder("forecast-it-test-single-series");
job.setAnalysisConfig(analysisConfig);
job.setDataDescription(dataDescription);
registerJob(job);
putJob(job);
openJob(job.getId());
long now = Instant.now().getEpochSecond();
long timestamp = now - 50 * bucketSpan.seconds();
List<String> data = new ArrayList<>();
while (timestamp < now) {
data.add(createJsonRecord(createRecord(timestamp, 10.0)));
data.add(createJsonRecord(createRecord(timestamp, 30.0)));
timestamp += bucketSpan.seconds();
}
postData(job.getId(), data.stream().collect(Collectors.joining()));
flushJob(job.getId(), false);
// Now we can start doing forecast requests
String forecastId = forecast(job.getId(),
TimeValue.timeValueHours(1),
TimeValue.ZERO,
new ByteSizeValue(50, ByteSizeUnit.MB).getBytes());
waitForecastToFinish(job.getId(), forecastId);
closeJob(job.getId());
List<ForecastRequestStats> forecastStats = getForecastStats();
ForecastRequestStats forecastDuration1HourNoExpiry = forecastStats.get(0);
assertThat(forecastDuration1HourNoExpiry.getExpiryTime(), equalTo(Instant.EPOCH));
List<Forecast> forecasts = getForecasts(job.getId(), forecastDuration1HourNoExpiry);
assertThat(forecastDuration1HourNoExpiry.getRecordCount(), equalTo(1L));
assertThat(forecasts.size(), equalTo(1));
}
private void createDataWithLotsOfClientIps(TimeValue bucketSpan, Job.Builder job) {
long now = Instant.now().getEpochSecond();
long timestamp = now - 15 * bucketSpan.seconds();

View File

@ -258,6 +258,10 @@ abstract class MlNativeAutodetectIntegTestCase extends MlNativeIntegTestCase {
}
protected String forecast(String jobId, TimeValue duration, TimeValue expiresIn) {
return forecast(jobId, duration, expiresIn, null);
}
protected String forecast(String jobId, TimeValue duration, TimeValue expiresIn, Long maxMemory) {
ForecastJobAction.Request request = new ForecastJobAction.Request(jobId);
if (duration != null) {
request.setDuration(duration.getStringRep());
@ -265,6 +269,9 @@ abstract class MlNativeAutodetectIntegTestCase extends MlNativeIntegTestCase {
if (expiresIn != null) {
request.setExpiresIn(expiresIn.getStringRep());
}
if (maxMemory != null) {
request.setMaxModelMemory(maxMemory);
}
return client().execute(ForecastJobAction.INSTANCE, request).actionGet().getForecastId();
}

View File

@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.ml.action;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
@ -16,7 +18,10 @@ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor;
import org.elasticsearch.xpack.core.ml.action.ForecastJobAction;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.results.ForecastRequestStats;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -24,6 +29,7 @@ import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.ForecastParams;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.process.NativeStorageProvider;
import java.nio.file.Path;
@ -31,27 +37,30 @@ import java.util.List;
import java.util.function.Consumer;
import static org.elasticsearch.xpack.core.ml.action.ForecastJobAction.Request.DURATION;
import static org.elasticsearch.xpack.core.ml.action.ForecastJobAction.Request.FORECAST_LOCAL_STORAGE_LIMIT;
public class TransportForecastJobAction extends TransportJobTaskAction<ForecastJobAction.Request,
ForecastJobAction.Response> {
private static final ByteSizeValue FORECAST_LOCAL_STORAGE_LIMIT = new ByteSizeValue(500, ByteSizeUnit.MB);
private static final Logger logger = LogManager.getLogger(TransportForecastJobAction.class);
private final JobResultsProvider jobResultsProvider;
private final JobManager jobManager;
private final NativeStorageProvider nativeStorageProvider;
private final AnomalyDetectionAuditor auditor;
@Inject
public TransportForecastJobAction(TransportService transportService,
ClusterService clusterService, ActionFilters actionFilters,
JobResultsProvider jobResultsProvider, AutodetectProcessManager processManager,
JobManager jobManager, NativeStorageProvider nativeStorageProvider) {
JobManager jobManager, NativeStorageProvider nativeStorageProvider, AnomalyDetectionAuditor auditor) {
super(ForecastJobAction.NAME, clusterService, transportService, actionFilters,
ForecastJobAction.Request::new, ForecastJobAction.Response::new,
ThreadPool.Names.SAME, processManager);
this.jobResultsProvider = jobResultsProvider;
this.jobManager = jobManager;
this.nativeStorageProvider = nativeStorageProvider;
this.auditor = auditor;
// ThreadPool.Names.SAME, because operations is executed by autodetect worker thread
}
@ -72,6 +81,11 @@ public class TransportForecastJobAction extends TransportJobTaskAction<ForecastJ
paramsBuilder.expiresIn(request.getExpiresIn());
}
Long adjustedLimit = getAdjustedMemoryLimit(job, request.getMaxModelMemory(), auditor);
if (adjustedLimit != null) {
paramsBuilder.maxModelMemory(adjustedLimit);
}
// tmp storage might be null, we do not log here, because it might not be
// required
Path tmpStorage = nativeStorageProvider.tryGetLocalTmpStorage(task.getDescription(), FORECAST_LOCAL_STORAGE_LIMIT);
@ -124,6 +138,26 @@ public class TransportForecastJobAction extends TransportJobTaskAction<ForecastJ
jobResultsProvider.getForecastRequestStats(jobId, forecastId, forecastRequestStatsHandler, listener::onFailure);
}
static Long getAdjustedMemoryLimit(Job job, Long requestedLimit, AbstractAuditor<? extends AbstractAuditMessage> auditor) {
if (requestedLimit == null) {
return null;
}
long jobLimitMegaBytes = job.getAnalysisLimits() == null || job.getAnalysisLimits().getModelMemoryLimit() == null ?
AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB :
job.getAnalysisLimits().getModelMemoryLimit();
long allowedMax = (long)(new ByteSizeValue(jobLimitMegaBytes, ByteSizeUnit.MB).getBytes() * 0.40);
long adjustedMax = Math.min(requestedLimit, allowedMax - 1);
if (adjustedMax != requestedLimit) {
String msg = "requested forecast memory limit [" +
requestedLimit +
"] bytes is greater than or equal to [" + allowedMax +
"] bytes (40% of the job memory limit). Reducing to [" + adjustedMax + "].";
logger.warn("[{}] {}", job.getId(), msg);
auditor.warning(job.getId(), msg);
}
return adjustedMax;
}
static void validate(Job job, ForecastJobAction.Request request) {
if (job.getJobVersion() == null || job.getJobVersion().before(Version.V_6_1_0)) {
throw ExceptionsHelper.badRequestException(

View File

@ -17,13 +17,15 @@ public class ForecastParams {
private final long duration;
private final long expiresIn;
private final String tmpStorage;
private final Long maxModelMemory;
private ForecastParams(String forecastId, long createTime, long duration, long expiresIn, String tmpStorage) {
private ForecastParams(String forecastId, long createTime, long duration, long expiresIn, String tmpStorage, Long maxModelMemory) {
this.forecastId = forecastId;
this.createTime = createTime;
this.duration = duration;
this.expiresIn = expiresIn;
this.tmpStorage = tmpStorage;
this.maxModelMemory = maxModelMemory;
}
public String getForecastId() {
@ -63,9 +65,13 @@ public class ForecastParams {
return tmpStorage;
}
public Long getMaxModelMemory() {
return maxModelMemory;
}
@Override
public int hashCode() {
return Objects.hash(forecastId, createTime, duration, expiresIn, tmpStorage);
return Objects.hash(forecastId, createTime, duration, expiresIn, tmpStorage, maxModelMemory);
}
@Override
@ -81,7 +87,8 @@ public class ForecastParams {
&& Objects.equals(createTime, other.createTime)
&& Objects.equals(duration, other.duration)
&& Objects.equals(expiresIn, other.expiresIn)
&& Objects.equals(tmpStorage, other.tmpStorage);
&& Objects.equals(tmpStorage, other.tmpStorage)
&& Objects.equals(maxModelMemory, other.maxModelMemory);
}
public static Builder builder() {
@ -93,6 +100,7 @@ public class ForecastParams {
private final long createTimeEpochSecs;
private long durationSecs;
private long expiresInSecs;
private Long maxModelMemory;
private String tmpStorage;
private Builder() {
@ -119,8 +127,13 @@ public class ForecastParams {
return this;
}
public Builder maxModelMemory(long maxModelMemory) {
this.maxModelMemory = maxModelMemory;
return this;
}
public ForecastParams build() {
return new ForecastParams(forecastId, createTimeEpochSecs, durationSecs, expiresInSecs, tmpStorage);
return new ForecastParams(forecastId, createTimeEpochSecs, durationSecs, expiresInSecs, tmpStorage, maxModelMemory);
}
}
}

View File

@ -158,6 +158,9 @@ public class AutodetectControlMsgWriter extends AbstractControlMsgWriter {
if (params.getTmpStorage() != null) {
builder.field("tmp_storage", params.getTmpStorage());
}
if (params.getMaxModelMemory() != null) {
builder.field("max_model_memory", params.getMaxModelMemory());
}
builder.endObject();
writeMessage(FORECAST_MESSAGE_CODE + Strings.toString(builder));

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.rest.job;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
@ -56,6 +57,13 @@ public class RestForecastJobAction extends BaseRestHandler {
if (restRequest.hasParam(ForecastJobAction.Request.EXPIRES_IN.getPreferredName())) {
request.setExpiresIn(restRequest.param(ForecastJobAction.Request.EXPIRES_IN.getPreferredName()));
}
if (restRequest.hasParam(ForecastJobAction.Request.MAX_MODEL_MEMORY.getPreferredName())) {
long limit = ByteSizeValue.parseBytesSizeValue(
restRequest.param(ForecastJobAction.Request.MAX_MODEL_MEMORY.getPreferredName()),
ForecastJobAction.Request.MAX_MODEL_MEMORY.getPreferredName()
).getBytes();
request.setMaxModelMemory(limit);
}
}
return channel -> client.execute(ForecastJobAction.INSTANCE, request, new RestToXContentListener<>(channel));

View File

@ -7,17 +7,28 @@ package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor;
import org.elasticsearch.xpack.core.ml.action.ForecastJobAction;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.notifications.AnomalyDetectionAuditMessage;
import java.util.Collections;
import java.util.Date;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
public class TransportForecastJobActionRequestTests extends ESTestCase {
public void testValidate_jobVersionCannonBeBefore61() {
@ -53,6 +64,55 @@ public class TransportForecastJobActionRequestTests extends ESTestCase {
assertEquals("[duration] must be greater or equal to the bucket span: [1m/1h]", e.getMessage());
}
public void testAdjustLimit() {
Job.Builder jobBuilder = createTestJob("forecast-adjust-limit");
NullAuditor auditor = new NullAuditor();
{
assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(jobBuilder.build(), null, auditor), is(nullValue()));
assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(
jobBuilder.build(),
new ByteSizeValue(20, ByteSizeUnit.MB).getBytes(),
auditor),
equalTo(new ByteSizeValue(20, ByteSizeUnit.MB).getBytes()));
assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(
jobBuilder.build(),
new ByteSizeValue(499, ByteSizeUnit.MB).getBytes(),
auditor),
equalTo(new ByteSizeValue(499, ByteSizeUnit.MB).getBytes()));
}
{
long limit = new ByteSizeValue(100, ByteSizeUnit.MB).getBytes();
assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(
jobBuilder.setAnalysisLimits(new AnalysisLimits(1L)).build(),
limit,
auditor),
equalTo(104857600L));
}
{
long limit = 429496732L;
assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(
jobBuilder.setAnalysisLimits(new AnalysisLimits(1L)).build(),
limit,
auditor),
equalTo(429496728L));
}
{
long limit = new ByteSizeValue(200, ByteSizeUnit.MB).getBytes();
assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(jobBuilder.build(), limit, auditor), equalTo(limit));
// gets adjusted down due to job analysis limits
assertThat(TransportForecastJobAction.getAdjustedMemoryLimit(
jobBuilder.setAnalysisLimits(new AnalysisLimits(200L, null)).build(),
limit,
auditor),
equalTo(new ByteSizeValue(80, ByteSizeUnit.MB).getBytes() - 1L));
}
}
private Job.Builder createTestJob(String jobId) {
Job.Builder jobBuilder = new Job.Builder(jobId);
jobBuilder.setCreateTime(new Date());
@ -66,4 +126,23 @@ public class TransportForecastJobActionRequestTests extends ESTestCase {
jobBuilder.setDataDescription(dataDescription);
return jobBuilder;
}
static class NullAuditor extends AbstractAuditor<AnomalyDetectionAuditMessage> {
protected NullAuditor() {
super(mock(Client.class), "test", "null", "foo", AnomalyDetectionAuditMessage::new);
}
@Override
public void info(String resourceId, String message) {
}
@Override
public void warning(String resourceId, String message) {
}
@Override
public void error(String resourceId, String message) {
}
}
}

View File

@ -31,6 +31,11 @@
"type":"time",
"required":false,
"description":"The time interval after which the forecast expires. Expired forecasts will be deleted at the first opportunity."
},
"max_model_memory":{
"type":"string",
"required":false,
"description":"The max memory able to be used by the forecast. Default is 20mb."
}
}
}

View File

@ -62,3 +62,10 @@ setup:
ml.forecast:
job_id: "forecast-job"
expires_in: "-1s"
---
"Test forecast given max_model_memory is too large":
- do:
catch: /\[max_model_memory\] must be less than 500mb/
ml.forecast:
job_id: "forecast-job"
max_model_memory: "1000mb"