[7.x] [ML][Inference] PUT API (#50852) (#50887)

* [ML][Inference] PUT API (#50852)

This adds the `PUT` API for creating trained models that support our format.

This includes

* HLRC change for the API
* API creation
* Validations of model format and call

* fixing backport
This commit is contained in:
Benjamin Trent 2020-01-12 10:59:11 -05:00 committed by GitHub
parent 456de59698
commit fa116a6d26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 1648 additions and 423 deletions

View File

@ -73,6 +73,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@ -792,6 +793,16 @@ final class MLRequestConverters {
return new Request(HttpDelete.METHOD_NAME, endpoint);
}
static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) throws IOException {
    // Builds: PUT _ml/inference/<model_id> with the model configuration as the body.
    String modelId = putTrainedModelRequest.getTrainedModelConfig().getModelId();
    String endpoint = new EndpointBuilder()
        .addPathPartAsIs("_ml", "inference")
        .addPathPart(modelId)
        .build();
    Request request = new Request(HttpPut.METHOD_NAME, endpoint);
    request.setEntity(createEntity(putTrainedModelRequest, REQUEST_BODY_CONTENT_TYPE));
    return request;
}
static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml")

View File

@ -100,6 +100,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@ -2340,6 +2342,48 @@ public final class MachineLearningClient {
Collections.emptySet());
}
/**
 * Put trained model config
 * <p>
 * Stores the given trained model configuration in the cluster so it can be used for inference.
 * For additional info
 * see <a href="TODO">
 * PUT Trained Model Config documentation</a> (NOTE(review): link is a placeholder — fill in once the docs page is published)
 *
 * @param request The {@link PutTrainedModelRequest} wrapping the trained model config to store
 * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
 * @return {@link PutTrainedModelResponse} response object containing the stored trained model config
 * @throws IOException when there is an issue sending the request or parsing the response body
 */
public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, RequestOptions options) throws IOException {
    return restHighLevelClient.performRequestAndParseEntity(request,
        MLRequestConverters::putTrainedModel,
        options,
        PutTrainedModelResponse::fromXContent,
        Collections.emptySet());
}
/**
 * Put trained model config asynchronously and notifies listener upon completion
 * <p>
 * Stores the given trained model configuration in the cluster so it can be used for inference.
 * For additional info
 * see <a href="TODO">
 * PUT Trained Model Config documentation</a> (NOTE(review): link is a placeholder — fill in once the docs page is published)
 *
 * @param request The {@link PutTrainedModelRequest} wrapping the trained model config to store
 * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
 * @param listener Listener to be notified upon request completion
 * @return cancellable that may be used to cancel the request
 */
public Cancellable putTrainedModelAsync(PutTrainedModelRequest request,
                                        RequestOptions options,
                                        ActionListener<PutTrainedModelResponse> listener) {
    return restHighLevelClient.performRequestAsyncAndParseEntity(request,
        MLRequestConverters::putTrainedModel,
        options,
        PutTrainedModelResponse::fromXContent,
        listener,
        Collections.emptySet());
}
/**
* Gets trained model stats
* <p>

View File

@ -0,0 +1,66 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
 * Request to store a trained model configuration for inference.
 * The request body is the serialized {@link TrainedModelConfig} itself.
 */
public class PutTrainedModelRequest implements Validatable, ToXContentObject {

    private final TrainedModelConfig config;

    /**
     * @param config the trained model configuration to store; must not be null
     */
    public PutTrainedModelRequest(TrainedModelConfig config) {
        // Fail fast on null instead of surfacing an NPE later from toXContent()/toString().
        this.config = Objects.requireNonNull(config, "[config] must not be null");
    }

    public TrainedModelConfig getTrainedModelConfig() {
        return config;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        // The request body is exactly the serialized model config.
        return config.toXContent(builder, params);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        PutTrainedModelRequest request = (PutTrainedModelRequest) o;
        return Objects.equals(config, request.config);
    }

    @Override
    public int hashCode() {
        return Objects.hash(config);
    }

    @Override
    public final String toString() {
        return Strings.toString(config);
    }
}

View File

@ -0,0 +1,63 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
/**
 * Response from storing a trained model; wraps the {@link TrainedModelConfig}
 * as returned by the cluster.
 */
public class PutTrainedModelResponse implements ToXContentObject {

    private final TrainedModelConfig trainedModelConfig;

    /** Parses the response body, which is a single trained model config object. */
    public static PutTrainedModelResponse fromXContent(XContentParser parser) throws IOException {
        return new PutTrainedModelResponse(TrainedModelConfig.PARSER.parse(parser, null).build());
    }

    public PutTrainedModelResponse(TrainedModelConfig trainedModelConfig) {
        // Fail fast on null instead of surfacing an NPE later from toXContent().
        this.trainedModelConfig = Objects.requireNonNull(trainedModelConfig, "[trained_model_config] must not be null");
    }

    /** @return the trained model config contained in the response */
    public TrainedModelConfig getResponse() {
        return trainedModelConfig;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return trainedModelConfig.toXContent(builder, params);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        PutTrainedModelResponse response = (PutTrainedModelResponse) o;
        return Objects.equals(trainedModelConfig, response.trainedModelConfig);
    }

    @Override
    public int hashCode() {
        return Objects.hash(trainedModelConfig);
    }
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
/**
 * Collection of helper methods for (de)serializing trained model content.
 * Similar to CompressedXContent, but this utilizes GZIP and Base64-encodes the result.
 */
public final class InferenceToXContentCompressor {
    private static final int BUFFER_SIZE = 4096;
    // Cap on how many bytes may be read back while inflating, guarding against decompression bombs.
    private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum

    private InferenceToXContentCompressor() {}

    /**
     * Serializes the object to JSON, GZIP-compresses it and returns the Base64-encoded result.
     *
     * @param objectToCompress the object to serialize and compress
     * @return Base64 string of the GZIP-compressed JSON representation
     * @throws IOException on serialization or compression failure
     */
    public static <T extends ToXContentObject> String deflate(T objectToCompress) throws IOException {
        BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
        return deflate(reference);
    }

    /**
     * Inverse of {@link #deflate(ToXContentObject)}: Base64-decodes, GZIP-inflates (bounded by
     * {@link #MAX_INFLATED_BYTES}) and parses the JSON with the supplied parser function.
     *
     * @throws IOException on decompression or parsing failure
     */
    public static <T> T inflate(String compressedString,
                                CheckedFunction<XContentParser, T, IOException> parserFunction,
                                NamedXContentRegistry xContentRegistry) throws IOException {
        try (XContentParser parser = XContentHelper.createParser(xContentRegistry,
            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
            inflate(compressedString, MAX_INFLATED_BYTES),
            XContentType.JSON)) {
            return parserFunction.apply(parser);
        }
    }

    static BytesReference inflate(String compressedString, long streamSize) throws IOException {
        // Base64.Decoder.decode(String) interprets the string as its ASCII bytes, so the former
        // explicit getBytes(UTF_8) round trip was redundant for the Base64 alphabet.
        byte[] compressedBytes = Base64.getDecoder().decode(compressedString);
        InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
        // The bounded stream reports EOF once streamSize bytes have been read, so oversized
        // payloads surface as a truncated (unparseable) document rather than unbounded memory use.
        InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
        return Streams.readFully(inflateStream);
    }

    private static String deflate(BytesReference reference) throws IOException {
        BytesStreamOutput out = new BytesStreamOutput();
        try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
            reference.writeTo(compressedOutput);
        }
        // encodeToString is equivalent to encoding to bytes and building a String from them.
        return Base64.getEncoder().encodeToString(BytesReference.toBytes(out.bytes()));
    }
}

View File

@ -0,0 +1,68 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
/**
 * A minimal bounded {@link InputStream} wrapper: only {@link #read()} enforces the byte budget.
 * Once the budget is spent, reads report end-of-stream rather than throwing.
 */
final class SimpleBoundedInputStream extends InputStream {

    private final InputStream in;
    private final long maxBytes;
    private long bytesRead;

    SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
        this.in = Objects.requireNonNull(inputStream, "inputStream");
        if (maxBytes < 0) {
            throw new IllegalArgumentException("[maxBytes] must be greater than or equal to 0");
        }
        this.maxBytes = maxBytes;
    }

    /**
     * Reads a single byte from the wrapped stream while enforcing the byte budget.
     * @return the byte read, or -1 on underlying stream completion or once maxBytes is reached
     * @throws IOException on failure of the wrapped stream
     */
    @Override
    public int read() throws IOException {
        if (bytesRead < maxBytes) {
            bytesRead++;
            return in.read();
        }
        // Budget spent: present this to callers as a normal end-of-stream.
        return -1;
    }

    /**
     * Delegates {@code close} to the wrapped InputStream.
     * @throws IOException on failure
     */
    @Override
    public void close() throws IOException {
        in.close();
    }
}

View File

@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -111,7 +112,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli());
this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
this.definition = definition;
this.compressedDefinition = compressedDefinition;
this.description = description;
@ -293,12 +294,12 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setCreatedBy(String createdBy) {
private Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
}
public Builder setVersion(Version version) {
private Builder setVersion(Version version) {
this.version = version;
return this;
}
@ -312,7 +313,7 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setCreateTime(Instant createTime) {
private Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
}
@ -322,6 +323,10 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setTags(String... tags) {
return setTags(Arrays.asList(tags));
}
public Builder setMetadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
@ -347,17 +352,17 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
public Builder setEstimatedOperations(Long estimatedOperations) {
private Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
public Builder setLicenseLevel(String licenseLevel) {
private Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
@ -48,6 +49,10 @@ public class TrainedModelInput implements ToXContentObject {
this.fieldNames = fieldNames;
}
/**
 * Convenience constructor taking the input field names as varargs.
 * @param fieldNames the names of the fields the model expects as input
 */
public TrainedModelInput(String... fieldNames) {
    this(Arrays.asList(fieldNames));
}
public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}

View File

@ -71,6 +71,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@ -91,6 +92,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.Detector;
import org.elasticsearch.client.ml.job.config.Job;
@ -874,6 +878,20 @@ public class MLRequestConvertersTests extends ESTestCase {
assertNull(request.getEntity());
}
// Verifies the request converter builds a PUT to _ml/inference/<model_id>
// whose body round-trips back into an equal TrainedModelConfig.
public void testPutTrainedModel() throws IOException {
    TrainedModelConfig trainedModelConfig = TrainedModelConfigTests.createTestTrainedModelConfig();
    PutTrainedModelRequest putTrainedModelRequest = new PutTrainedModelRequest(trainedModelConfig);
    Request request = MLRequestConverters.putTrainedModel(putTrainedModelRequest);
    assertEquals(HttpPut.METHOD_NAME, request.getMethod());
    assertThat(request.getEndpoint(), equalTo("/_ml/inference/" + trainedModelConfig.getModelId()));
    // Parse the serialized entity and compare against the original config.
    try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) {
        TrainedModelConfig parsedTrainedModelConfig = TrainedModelConfig.PARSER.apply(parser, null).build();
        assertThat(parsedTrainedModelConfig, equalTo(trainedModelConfig));
    }
}
public void testPutFilter() throws IOException {
MlFilter filter = MlFilterTests.createRandomBuilder("foo").build();
PutFilterRequest putFilterRequest = new PutFilterRequest(filter);
@ -1046,6 +1064,7 @@ public class MLRequestConvertersTests extends ESTestCase {
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}

View File

@ -101,6 +101,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@ -146,9 +148,12 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
@ -162,14 +167,12 @@ import org.elasticsearch.client.ml.job.config.MlFilter;
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@ -178,11 +181,9 @@ import org.elasticsearch.search.SearchHit;
import org.junit.After;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -190,7 +191,6 @@ import java.util.Locale;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
@ -2222,6 +2222,50 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
}
}
public void testPutTrainedModel() throws Exception {
    String modelId = "test-put-trained-model";
    String modelIdCompressed = "test-put-trained-model-compressed-definition";
    MachineLearningClient machineLearningClient = highLevelClient().machineLearning();

    // Case 1: PUT a model whose definition is supplied inline (uncompressed).
    TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
    TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
        .setDefinition(definition)
        .setModelId(modelId)
        .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
        .setDescription("test model")
        .build();
    PutTrainedModelResponse putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
        machineLearningClient::putTrainedModel,
        machineLearningClient::putTrainedModelAsync);
    TrainedModelConfig createdModel = putTrainedModelResponse.getResponse();
    assertThat(createdModel.getModelId(), equalTo(modelId));

    // Case 2: PUT a model whose definition is supplied pre-compressed.
    definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
    trainedModelConfig = TrainedModelConfig.builder()
        .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition))
        .setModelId(modelIdCompressed)
        .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
        .setDescription("test model")
        .build();
    putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
        machineLearningClient::putTrainedModel,
        machineLearningClient::putTrainedModelAsync);
    createdModel = putTrainedModelResponse.getResponse();
    assertThat(createdModel.getModelId(), equalTo(modelIdCompressed));

    // Fetch the compressed-definition model back, requesting the decompressed definition:
    // only the inflated definition should then be populated on the returned config.
    GetTrainedModelsResponse getTrainedModelsResponse = execute(
        new GetTrainedModelsRequest(modelIdCompressed).setDecompressDefinition(true).setIncludeDefinition(true),
        machineLearningClient::getTrainedModels,
        machineLearningClient::getTrainedModelsAsync);
    assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
    assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
    assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue()));
    assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue())));
    assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdCompressed));
}
public void testGetTrainedModelsStats() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String modelIdPrefix = "a-get-trained-model-stats-";
@ -2504,56 +2548,13 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
private void putTrainedModel(String modelId) throws IOException {
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
highLevelClient().index(
new IndexRequest(".ml-inference-000001")
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.source(modelConfigString(modelId), XContentType.JSON)
.id(modelId),
RequestOptions.DEFAULT);
highLevelClient().index(
new IndexRequest(".ml-inference-000001")
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
.id("trained_model_definition_doc-" + modelId + "-0"),
RequestOptions.DEFAULT);
}
private String compressDefinition(TrainedModelDefinition definition) throws IOException {
BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
BytesStreamOutput out = new BytesStreamOutput();
try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
reference.writeTo(compressedOutput);
}
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
}
private static String modelConfigString(String modelId) {
return "{\n" +
" \"doc_type\": \"trained_model_config\",\n" +
" \"model_id\": \"" + modelId + "\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +
"}";
}
private static String modelDocString(String compressedDefinition, String modelId) {
return "" +
"{" +
"\"model_id\": \"" + modelId + "\",\n" +
"\"doc_num\": 0,\n" +
"\"doc_type\": \"trained_model_definition_doc\",\n" +
" \"compression_version\": " + 1 + ",\n" +
" \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
" \"definition_length\": " + compressedDefinition.length() + ",\n" +
"\"definition\": \"" + compressedDefinition + "\"\n" +
"}";
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition)
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
.setDescription("test model")
.build();
highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
private void waitForJobToClose(String jobId) throws Exception {
@ -2798,4 +2799,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
mlInfoResponse = machineLearningClient.getMlInfo(new MlInfoRequest(), RequestOptions.DEFAULT);
assertThat(mlInfoResponse.getInfo().get("upgrade_mode"), equalTo(false));
}
// NOTE(review): the ML inference named-XContent parsers are needed so trained model
// definitions in responses can be parsed by this test's registry.
@Override
protected NamedXContentRegistry xContentRegistry() {
    return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
}

View File

@ -114,6 +114,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@ -162,10 +164,14 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
import org.elasticsearch.client.ml.job.config.DataDescription;
@ -186,12 +192,11 @@ import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@ -202,12 +207,10 @@ import org.elasticsearch.tasks.TaskId;
import org.junit.After;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
@ -216,7 +219,6 @@ import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
@ -3625,6 +3627,79 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
}
}
// Documentation test: the tag::/end:: markers delimit snippets extracted into the docs,
// so no comments may be added inside those regions.
public void testPutTrainedModel() throws Exception {
    TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
    // tag::put-trained-model-config
    TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
        .setDefinition(definition) // <1>
        .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2>
        .setModelId("my-new-trained-model") // <3>
        .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4>
        .setDescription("test model") // <5>
        .setMetadata(new HashMap<>()) // <6>
        .setTags("my_regression_models") // <7>
        .build();
    // end::put-trained-model-config

    // NOTE(review): rebuilt without the compressed definition — presumably setting both
    // definition forms is rejected by the PUT validation; confirm against the server API.
    trainedModelConfig = TrainedModelConfig.builder()
        .setDefinition(definition)
        .setModelId("my-new-trained-model")
        .setInput(new TrainedModelInput("col1", "col2", "col3", "col4"))
        .setDescription("test model")
        .setMetadata(new HashMap<>())
        .setTags("my_regression_models")
        .build();
    RestHighLevelClient client = highLevelClient();
    {
        // tag::put-trained-model-request
        PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); // <1>
        // end::put-trained-model-request
        // tag::put-trained-model-execute
        PutTrainedModelResponse response = client.machineLearning().putTrainedModel(request, RequestOptions.DEFAULT);
        // end::put-trained-model-execute
        // tag::put-trained-model-response
        TrainedModelConfig model = response.getResponse();
        // end::put-trained-model-response
        assertThat(model.getModelId(), equalTo(trainedModelConfig.getModelId()));
        // Delete so the async example below can re-create the same model id.
        highLevelClient().machineLearning()
            .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
    }
    {
        PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
        // tag::put-trained-model-execute-listener
        ActionListener<PutTrainedModelResponse> listener = new ActionListener<PutTrainedModelResponse>() {
            @Override
            public void onResponse(PutTrainedModelResponse response) {
                // <1>
            }
            @Override
            public void onFailure(Exception e) {
                // <2>
            }
        };
        // end::put-trained-model-execute-listener
        // Replace the empty listener by a blocking listener in test
        CountDownLatch latch = new CountDownLatch(1);
        listener = new LatchedActionListener<>(listener, latch);
        // tag::put-trained-model-execute-async
        client.machineLearning().putTrainedModelAsync(request, RequestOptions.DEFAULT, listener); // <1>
        // end::put-trained-model-execute-async
        assertTrue(latch.await(30L, TimeUnit.SECONDS));
        highLevelClient().machineLearning()
            .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
    }
}
public void testGetTrainedModelsStats() throws Exception {
putTrainedModel("my-trained-model");
RestHighLevelClient client = highLevelClient();
@ -4088,57 +4163,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
}
private void putTrainedModel(String modelId) throws IOException {
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
highLevelClient().index(
new IndexRequest(".ml-inference-000001")
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.source(modelConfigString(modelId), XContentType.JSON)
.id(modelId),
RequestOptions.DEFAULT);
highLevelClient().index(
new IndexRequest(".ml-inference-000001")
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
.id("trained_model_definition_doc-" + modelId + "-0"),
RequestOptions.DEFAULT);
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition)
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
.setDescription("test model")
.build();
highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
private String compressDefinition(TrainedModelDefinition definition) throws IOException {
BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
BytesStreamOutput out = new BytesStreamOutput();
try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
reference.writeTo(compressedOutput);
}
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
}
private static String modelConfigString(String modelId) {
return "{\n" +
" \"doc_type\": \"trained_model_config\",\n" +
" \"model_id\": \"" + modelId + "\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +
"}";
}
private static String modelDocString(String compressedDefinition, String modelId) {
return "" +
"{" +
"\"model_id\": \"" + modelId + "\",\n" +
"\"doc_num\": 0,\n" +
"\"doc_type\": \"trained_model_definition_doc\",\n" +
" \"compression_version\": " + 1 + ",\n" +
" \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
" \"definition_length\": " + compressedDefinition.length() + ",\n" +
"\"definition\": \"" + compressedDefinition + "\"\n" +
"}";
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG =

View File

@ -0,0 +1,52 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip tests for {@link PutTrainedModelRequest}: a request body is
 * simply a serialized {@link TrainedModelConfig}.
 */
public class PutTrainedModelActionRequestTests extends AbstractXContentTestCase<PutTrainedModelRequest> {

    @Override
    protected PutTrainedModelRequest createTestInstance() {
        TrainedModelConfig randomConfig = TrainedModelConfigTests.createTestTrainedModelConfig();
        return new PutTrainedModelRequest(randomConfig);
    }

    @Override
    protected PutTrainedModelRequest doParseInstance(XContentParser parser) throws IOException {
        // Parse the body straight into a config and wrap it in a request.
        TrainedModelConfig parsedConfig = TrainedModelConfig.PARSER.apply(parser, null).build();
        return new PutTrainedModelRequest(parsedConfig);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Do not inject random unknown fields into the round-tripped documents.
        return false;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // Registry resolving the named objects (models, preprocessors) inside the definition.
        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    }
}

View File

@ -0,0 +1,52 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip tests for {@link PutTrainedModelResponse}: the response body is
 * the stored {@link TrainedModelConfig} echoed back to the caller.
 */
public class PutTrainedModelActionResponseTests extends AbstractXContentTestCase<PutTrainedModelResponse> {

    @Override
    protected PutTrainedModelResponse createTestInstance() {
        TrainedModelConfig randomConfig = TrainedModelConfigTests.createTestTrainedModelConfig();
        return new PutTrainedModelResponse(randomConfig);
    }

    @Override
    protected PutTrainedModelResponse doParseInstance(XContentParser parser) throws IOException {
        // Parse the body straight into a config and wrap it in a response.
        TrainedModelConfig parsedConfig = TrainedModelConfig.PARSER.apply(parser, null).build();
        return new PutTrainedModelResponse(parsedConfig);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Do not inject random unknown fields into the round-tripped documents.
        return false;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // Registry resolving the named objects (models, preprocessors) inside the definition.
        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    }
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
/**
 * Tests for {@link InferenceToXContentCompressor}: deflate/inflate round-trips,
 * the stream-size limit, and rejection of undecodable input.
 */
public class InferenceToXContentCompressorTests extends ESTestCase {

    public void testInflateAndDeflate() throws IOException {
        // Repeat with several random definitions to cover varied payloads.
        for (int iteration = 0; iteration < 10; iteration++) {
            TrainedModelDefinition original = TrainedModelDefinitionTests.createRandomBuilder().build();
            String compressed = InferenceToXContentCompressor.deflate(original);
            TrainedModelDefinition roundTripped = InferenceToXContentCompressor.inflate(
                compressed,
                parser -> TrainedModelDefinition.fromXContent(parser).build(),
                xContentRegistry());
            // Did we inflate to the same object?
            assertThat(roundTripped, equalTo(original));
        }
    }

    public void testInflateTooLargeStream() throws IOException {
        TrainedModelDefinition original = TrainedModelDefinitionTests.createRandomBuilder().build();
        String compressed = InferenceToXContentCompressor.deflate(original);
        // Inflate with a 10-byte cap: the output is truncated to exactly that many bytes.
        BytesReference truncated = InferenceToXContentCompressor.inflate(compressed, 10L);
        assertThat(truncated.length(), equalTo(10));
        // Parsing the truncated JSON must fail.
        try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
            LoggingDeprecationHandler.INSTANCE,
            truncated,
            XContentType.JSON)) {
            expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser));
        }
    }

    public void testInflateGarbage() {
        // Random text is not valid base64/deflate input and must be rejected.
        expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    }
}

View File

@ -37,6 +37,24 @@ import java.util.stream.Stream;
public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedModelConfig> {
public static TrainedModelConfig createTestTrainedModelConfig() {
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null : randomAlphaOfLength(100),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomFrom("platinum", "basic"));
}
@Override
protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException {
return TrainedModelConfig.fromXContent(parser);
@ -54,22 +72,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
@Override
protected TrainedModelConfig createTestInstance() {
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null : randomAlphaOfLength(100),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomFrom("platinum", "basic"));
return createTestTrainedModelConfig();
}
@Override

View File

@ -67,15 +67,17 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
.collect(Collectors.toList());
int numberOfModels = randomIntBetween(1, 10);
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
.limit(numberOfFeatures)
.limit(numberOfModels)
.collect(Collectors.toList());
OutputAggregator outputAggregator = null;
if (randomBoolean()) {
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights));
List<OutputAggregator> possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
new LogisticRegression(weights)));
if (targetType.equals(TargetType.REGRESSION)) {
possibleAggregators.add(new WeightedSum(weights));
}
OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
List<String> categoryLabels = null;
if (randomBoolean()) {
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return new Ensemble(featureNames,

View File

@ -84,7 +84,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
childNodes = nextNodes;
}
List<String> categoryLabels = null;
if (randomBoolean()) {
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return builder.setClassificationLabels(categoryLabels)

View File

@ -0,0 +1,53 @@
--
:api: put-trained-model
:request: PutTrainedModelRequest
:response: PutTrainedModelResponse
--
[role="xpack"]
[id="{upid}-{api}"]
=== Put Trained Model API
Creates a new trained model for inference.
The API accepts a +{request}+ object as a request and returns a +{response}+.
[id="{upid}-{api}-request"]
==== Put Trained Model request
A +{request}+ requires the following argument:
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-request]
--------------------------------------------------
<1> The configuration of the {infer} Trained Model to create
[id="{upid}-{api}-config"]
==== Trained Model configuration
The `TrainedModelConfig` object contains all the details about the trained model
configuration and contains the following arguments:
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-config]
--------------------------------------------------
<1> The {infer} definition for the model
<2> Optionally, if the {infer} definition is large, you may choose to compress it for transport.
Do not supply both the compressed and uncompressed definitions.
<3> The unique model id
<4> The input field names for the model definition
<5> Optionally, a human-readable description
<6> Optionally, an object map containing metadata about the model
<7> Optionally, an array of tags to organize the model
include::../execution.asciidoc[]
[id="{upid}-{api}-response"]
==== Response
The returned +{response}+ contains the newly created trained model.
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-response]
--------------------------------------------------

View File

@ -304,6 +304,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
* <<{upid}-evaluate-data-frame>>
* <<{upid}-explain-data-frame-analytics>>
* <<{upid}-get-trained-models>>
* <<{upid}-put-trained-model>>
* <<{upid}-get-trained-models-stats>>
* <<{upid}-delete-trained-model>>
* <<{upid}-put-filter>>
@ -359,6 +360,7 @@ include::ml/stop-data-frame-analytics.asciidoc[]
include::ml/evaluate-data-frame.asciidoc[]
include::ml/explain-data-frame-analytics.asciidoc[]
include::ml/get-trained-models.asciidoc[]
include::ml/put-trained-model.asciidoc[]
include::ml/get-trained-models-stats.asciidoc[]
include::ml/delete-trained-model.asciidoc[]
include::ml/put-filter.asciidoc[]

View File

@ -126,6 +126,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
@ -381,6 +382,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,
PutTrainedModelAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,

View File

@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import java.io.IOException;
import java.util.Objects;
/**
 * Transport action for creating ("putting") a trained inference model.
 * The request wraps a fully validated {@link TrainedModelConfig}; the response echoes
 * the config as it was stored.
 */
public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {

    public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
    // NOTE(review): this is a write operation, yet the name uses the "cluster:monitor"
    // prefix, which elsewhere denotes read-only actions — confirm whether
    // "cluster:admin/xpack/ml/inference/put" was intended before privileges are built on it.
    public static final String NAME = "cluster:monitor/xpack/ml/inference/put";

    private PutTrainedModelAction() {
        super(NAME, Response::new);
    }

    public static class Request extends AcknowledgedRequest<Request> {

        /**
         * Parses a request body into a {@link TrainedModelConfig} and validates it for creation.
         *
         * @param modelId model id taken from the request URI; used when the body omits it and
         *                checked for consistency when both are present
         * @param parser  parser positioned on the request body
         * @return a request wrapping the validated, built config
         * @throws IllegalArgumentException if the URI and body model ids are both set but differ
         */
        public static Request parseRequest(String modelId, XContentParser parser) {
            TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);
            if (builder.getModelId() == null) {
                // Fix: the original chained a stray build() here and discarded its result;
                // only the builder mutation is needed at this point.
                builder.setModelId(modelId);
            } else if (!Strings.isNullOrEmpty(modelId) && !modelId.equals(builder.getModelId())) {
                // If we have model_id in both URI and body, they must be identical
                throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID,
                    TrainedModelConfig.MODEL_ID.getPreferredName(),
                    builder.getModelId(),
                    modelId));
            }
            // Validations are done against the builder so we can build the full config object.
            // This allows us to not worry about serializing a builder class between nodes.
            return new Request(builder.validate(true).build());
        }

        private final TrainedModelConfig config;

        public Request(TrainedModelConfig config) {
            this.config = config;
        }

        public Request(StreamInput in) throws IOException {
            super(in);
            this.config = new TrainedModelConfig(in);
        }

        public TrainedModelConfig getTrainedModelConfig() {
            return config;
        }

        @Override
        public ActionRequestValidationException validate() {
            // Validation already ran in parseRequest() against the config builder, before
            // the immutable config (and therefore this request) was constructed.
            return null;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            config.writeTo(out);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Request request = (Request) o;
            return Objects.equals(config, request.config);
        }

        @Override
        public int hashCode() {
            return Objects.hash(config);
        }

        @Override
        public final String toString() {
            return Strings.toString(config);
        }
    }

    public static class Response extends ActionResponse implements ToXContentObject {

        /** The trained model configuration as stored. */
        private final TrainedModelConfig trainedModelConfig;

        public Response(TrainedModelConfig trainedModelConfig) {
            this.trainedModelConfig = trainedModelConfig;
        }

        public Response(StreamInput in) throws IOException {
            super(in);
            trainedModelConfig = new TrainedModelConfig(in);
        }

        public TrainedModelConfig getResponse() {
            return trainedModelConfig;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            trainedModelConfig.writeTo(out);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            // The response body is exactly the stored config.
            return trainedModelConfig.toXContent(builder, params);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Response response = (Response) o;
            return Objects.equals(trainedModelConfig, response.trainedModelConfig);
        }

        @Override
        public int hashCode() {
            return Objects.hash(trainedModelConfig);
        }
    }
}

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
@ -34,6 +35,9 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.elasticsearch.action.ValidateActions.addValidationError;
public class TrainedModelConfig implements ToXContentObject, Writeable {
@ -352,13 +356,31 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private Long estimatedHeapMemory;
private Long estimatedOperations;
private LazyModelDefinition definition;
private String licenseLevel = License.OperationMode.PLATINUM.description();
private String licenseLevel;
public Builder() {}
public Builder(TrainedModelConfig config) {
this.modelId = config.getModelId();
this.createdBy = config.getCreatedBy();
this.version = config.getVersion();
this.createTime = config.getCreateTime();
this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
this.description = config.getDescription();
this.tags = config.getTags();
this.metadata = config.getMetadata();
this.input = config.getInput();
}
public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}
public String getModelId() {
return this.modelId;
}
public Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
@ -466,51 +488,96 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
// TODO move to REST level instead of here in the builder
public void validate() {
public Builder validate() {
return validate(false);
}
/**
* Runs validations against the builder.
* @return The current builder object if validations are successful
* @throws ActionRequestValidationException when there are validation failures.
*/
public Builder validate(boolean forCreation) {
// We require a definition to be available here even though it will be stored in a different doc
ExceptionsHelper.requireNonNull(definition, DEFINITION);
ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
if (MlStrings.isValidId(modelId) == false) {
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId));
ActionRequestValidationException validationException = null;
if (definition == null) {
validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] must not be null.", validationException);
}
if (modelId == null) {
validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException);
}
if (MlStrings.hasValidLengthForId(modelId) == false) {
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG,
MODEL_ID.getPreferredName(),
if (modelId != null && MlStrings.isValidId(modelId) == false) {
validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID,
TrainedModelConfig.MODEL_ID.getPreferredName(),
modelId),
validationException);
}
if (modelId != null && MlStrings.hasValidLengthForId(modelId) == false) {
validationException = addValidationError(Messages.getMessage(Messages.ID_TOO_LONG,
TrainedModelConfig.MODEL_ID.getPreferredName(),
modelId,
MlStrings.ID_LENGTH_LIMIT));
MlStrings.ID_LENGTH_LIMIT), validationException);
}
List<String> badTags = tags.stream()
.filter(tag -> (MlStrings.isValidId(tag) && MlStrings.hasValidLengthForId(tag)) == false)
.collect(Collectors.toList());
if (badTags.isEmpty() == false) {
validationException = addValidationError(Messages.getMessage(Messages.INFERENCE_INVALID_TAGS,
badTags,
MlStrings.ID_LENGTH_LIMIT),
validationException);
}
checkIllegalSetting(version, VERSION.getPreferredName());
checkIllegalSetting(createdBy, CREATED_BY.getPreferredName());
checkIllegalSetting(createTime, CREATE_TIME.getPreferredName());
checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName());
checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName());
for(String tag : tags) {
if (tag.equals(modelId)) {
validationException = addValidationError("none of the tags must equal the model_id", validationException);
break;
}
}
if (forCreation) {
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
validationException = checkIllegalSetting(estimatedHeapMemory,
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
validationException);
validationException = checkIllegalSetting(estimatedOperations,
ESTIMATED_OPERATIONS.getPreferredName(),
validationException);
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
}
private static void checkIllegalSetting(Object value, String setting) {
if (validationException != null) {
throw validationException;
}
return this;
}
private static ActionRequestValidationException checkIllegalSetting(Object value,
String setting,
ActionRequestValidationException validationException) {
if (value != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting);
return addValidationError("illegal to set [" + setting + "] at inference model creation", validationException);
}
return validationException;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
createdBy,
version,
createdBy == null ? "user" : createdBy,
version == null ? Version.CURRENT : version,
description,
createTime == null ? Instant.now() : createTime,
definition,
tags,
metadata,
input,
estimatedHeapMemory,
estimatedOperations,
licenseLevel);
estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
estimatedOperations == null ? 0 : estimatedOperations,
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
}
}
@ -531,6 +598,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return new LazyModelDefinition(input.readString(), null);
}
private LazyModelDefinition(LazyModelDefinition definition) {
if (definition != null) {
this.compressedString = definition.compressedString;
this.parsedDefinition = definition.parsedDefinition;
}
}
private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) {
if (compressedString == null && trainedModelDefinition == null) {
throw new IllegalArgumentException("unexpected null model definition");

View File

@ -179,6 +179,12 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
this(true);
}
public Builder(TrainedModelDefinition definition) {
this(true);
this.preProcessors = new ArrayList<>(definition.getPreProcessors());
this.trainedModel = definition.trainedModel;
}
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
this.preProcessors = preProcessors;
return this;

View File

@ -95,6 +95,10 @@ public final class Messages {
public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
public static final String INFERENCE_INVALID_TAGS = "Invalid tags {0}; must only can contain lowercase alphanumeric (a-z and 0-9), " +
"hyphens or underscores, must start and end with alphanumeric, and must be less than {1} characters.";
public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids.";
public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags.";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";

View File

@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
/**
 * Wire-serialization round-trip tests for {@link PutTrainedModelAction.Request}.
 */
public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {

    @Override
    protected Request createTestInstance() {
        // Random config with an explicitly parsed (uncompressed) definition attached.
        String modelId = randomAlphaOfLength(10);
        return new Request(TrainedModelConfigTests.createTestInstance(modelId)
            .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
            .build());
    }

    @Override
    protected Writeable.Reader<Request> instanceReader() {
        return (in) -> {
            Request request = new Request(in);
            // Force the definition into its parsed form after deserialization so that the
            // framework's equality check compares like with like — presumably the wire form
            // keeps the definition lazily/compressed; TODO confirm against TrainedModelConfig.
            request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry());
            return request;
        };
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // Registry resolving named objects (models, preprocessors) when parsing definitions.
        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        // Registry resolving the named writeables used by definitions on the wire.
        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
    }
}

View File

@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
/**
 * Wire-serialization round-trip tests for {@link Response} (PUT trained model).
 */
public class PutTrainedModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {

    @Override
    protected Response createTestInstance() {
        // Random config under a random id, with the definition parsed up front so the
        // round-trip equality check compares parsed definitions on both sides.
        String id = randomAlphaOfLength(10);
        return new Response(
            TrainedModelConfigTests.createTestInstance(id)
                .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                .build());
    }

    @Override
    protected Writeable.Reader<Response> instanceReader() {
        return in -> {
            Response deserialized = new Response(in);
            // Parse the definition after deserialization to mirror createTestInstance().
            deserialized.getResponse().ensureParsedDefinition(xContentRegistry());
            return deserialized;
        };
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        MlInferenceNamedXContentProvider provider = new MlInferenceNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        MlInferenceNamedXContentProvider provider = new MlInferenceNamedXContentProvider();
        return new NamedWriteableRegistry(provider.getNamedWriteables());
    }
}

View File

@ -5,8 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
@ -56,14 +56,16 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
return TrainedModelConfig.builder()
.setInput(TrainedModelInputTests.createRandomInput())
.setMetadata(randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)))
.setCreateTime(Instant.ofEpochMilli(randomNonNegativeLong()))
.setCreateTime(Instant.ofEpochMilli(randomLongBetween(Instant.MIN.getEpochSecond(), Instant.MAX.getEpochSecond())))
.setVersion(Version.CURRENT)
.setModelId(modelId)
.setCreatedBy(randomAlphaOfLength(10))
.setDescription(randomBoolean() ? null : randomAlphaOfLength(100))
.setEstimatedHeapMemory(randomNonNegativeLong())
.setEstimatedOperations(randomNonNegativeLong())
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(),
License.OperationMode.GOLD.description(),
License.OperationMode.BASIC.description()))
.setTags(tags);
}
@ -191,50 +193,52 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
}
public void testValidateWithNullDefinition() {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder().validate());
assertThat(ex.getMessage(), containsString("[definition] must not be null."));
}
public void testValidateWithInvalidID() {
String modelId = "InvalidID-";
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
assertThat(ex.getMessage(), containsString(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
}
public void testValidateWithLongID() {
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
assertThat(ex.getMessage(),
containsString(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
}
public void testValidateWithIllegallyUserProvidedFields() {
String modelId = "simplemodel";
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreateTime(Instant.now())
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
.setModelId(modelId).validate(true));
assertThat(ex.getMessage(), containsString("illegal to set [create_time] at inference model creation"));
ex = expectThrows(ElasticsearchException.class,
ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setVersion(Version.CURRENT)
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
.setModelId(modelId).validate(true));
assertThat(ex.getMessage(), containsString("illegal to set [version] at inference model creation"));
ex = expectThrows(ElasticsearchException.class,
ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreatedBy("ml_user")
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
.setModelId(modelId).validate(true));
assertThat(ex.getMessage(), containsString("illegal to set [created_by] at inference model creation"));
}
public void testSerializationWithLazyDefinition() throws IOException {

View File

@ -133,7 +133,6 @@ integTest.runner {
'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id',
'ml/get_datafeeds/Test get datafeed given missing datafeed_id',
'ml/inference_crud/Test delete given used trained model',
'ml/inference_crud/Test delete given unused trained model',
'ml/inference_crud/Test delete with missing model',
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',

View File

@ -9,16 +9,20 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
@ -35,26 +39,14 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
@Before
public void createBothModels() throws Exception {
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId("test_classification")
.setSource(CLASSIFICATION_CONFIG, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(TrainedModelDefinitionDoc.docId("test_classification", 0))
.setSource(buildClassificationModelDoc(), XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId("test_regression")
.setSource(REGRESSION_CONFIG, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(TrainedModelDefinitionDoc.docId("test_regression", 0))
.setSource(buildRegressionModelDoc(), XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
}
@After
public void deleteBothModels() {
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
}
public void testPipelineCreationAndDeletion() throws Exception {
@ -392,6 +384,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"definition\": " + REGRESSION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
@ -519,28 +512,27 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }\n" +
"}";
private static String buildClassificationModelDoc() throws IOException {
String compressed =
InferenceToXContentCompressor.deflate(new BytesArray(CLASSIFICATION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
return modelDocString(compressed, "test_classification");
private TrainedModelConfig buildClassificationModel() throws IOException {
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(CLASSIFICATION_CONFIG),
XContentType.JSON)) {
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
}
}
private static String buildRegressionModelDoc() throws IOException {
String compressed = InferenceToXContentCompressor.deflate(new BytesArray(REGRESSION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
return modelDocString(compressed, "test_regression");
private TrainedModelConfig buildRegressionModel() throws IOException {
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(REGRESSION_CONFIG),
XContentType.JSON)) {
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
}
}
private static String modelDocString(String compressedDefinition, String modelId) {
return "" +
"{" +
"\"model_id\": \"" + modelId + "\",\n" +
"\"doc_num\": 0,\n" +
"\"doc_type\": \"trained_model_definition_doc\",\n" +
" \"compression_version\": " + 1 + ",\n" +
" \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
" \"definition_length\": " + compressedDefinition.length() + ",\n" +
"\"definition\": \"" + compressedDefinition + "\"\n" +
"}";
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
private static final String CLASSIFICATION_CONFIG = "" +
@ -549,8 +541,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"definition\": " + CLASSIFICATION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"created_by\": \"es_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +

View File

@ -6,10 +6,18 @@
package org.elasticsearch.xpack.ml.integration;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.Version;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -18,26 +26,19 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.license.License;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.junit.After;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.containsString;
@ -62,22 +63,8 @@ public class TrainedModelIT extends ESRestTestCase {
public void testGetTrainedModels() throws IOException {
String modelId = "a_test_regression_model";
String modelId2 = "a_test_regression_model-2";
Request model1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
model1.setJsonEntity(buildRegressionModel(modelId));
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
Request modelDefinition1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
Request model2 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
model2.setJsonEntity(buildRegressionModel(modelId2));
assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201));
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
putRegressionModel(modelId);
putRegressionModel(modelId2);
Response getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/" + modelId));
@ -164,17 +151,7 @@ public class TrainedModelIT extends ESRestTestCase {
public void testDeleteTrainedModels() throws IOException {
String modelId = "test_delete_regression_model";
Request model1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
model1.setJsonEntity(buildRegressionModel(modelId));
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
Request modelDefinition1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
putRegressionModel(modelId);
Response delModel = client().performRequest(new Request("DELETE",
MachineLearning.BASE_PATH + "inference/" + modelId));
@ -208,41 +185,67 @@ public class TrainedModelIT extends ESRestTestCase {
assertThat(response, containsString("\"definition\""));
}
private static String buildRegressionModel(String modelId) throws IOException {
private void putRegressionModel(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Collections.emptyList())
.setTrainedModel(buildRegression());
TrainedModelConfig.builder()
.setDefinition(definition)
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
.setCreatedBy("ml_test")
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(0)
.build()
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
.build().toXContent(builder, ToXContent.EMPTY_PARAMS);
Request model = new Request("PUT", "_ml/inference/" + modelId);
model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
}
}
private static String buildRegressionModelDefinitionDoc(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Collections.emptyList())
.setTrainedModel(LocalModelTests.buildRegression())
private static TrainedModel buildRegression() {
List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5),
TreeNode.builder(1).setLeafValue(0.3),
TreeNode.builder(2)
.setThreshold(0.0)
.setSplitFeature(3)
.setLeftChild(3)
.setRightChild(4),
TreeNode.builder(3).setLeafValue(0.1),
TreeNode.builder(4).setLeafValue(0.2))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(2)
.setThreshold(1.0),
TreeNode.builder(1).setLeafValue(1.5),
TreeNode.builder(2).setLeafValue(0.9))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(0.2),
TreeNode.builder(1).setLeafValue(1.5),
TreeNode.builder(2).setLeafValue(0.9))
.build();
return Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5)))
.build();
String compressedString = InferenceToXContentCompressor.deflate(definition);
TrainedModelDefinitionDoc doc = new TrainedModelDefinitionDoc.Builder().setDocNum(0)
.setCompressedString(compressedString)
.setTotalDefinitionLength(compressedString.length())
.setDefinitionLength(compressedString.length())
.setCompressionVersion(1)
.setModelId(modelId).build();
doc.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
}
}
@After
public void clearMlState() throws Exception {

View File

@ -112,6 +112,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
@ -183,6 +184,7 @@ import org.elasticsearch.xpack.ml.action.TransportPutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.ml.action.TransportPutDatafeedAction;
import org.elasticsearch.xpack.ml.action.TransportPutFilterAction;
import org.elasticsearch.xpack.ml.action.TransportPutJobAction;
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAction;
import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
@ -273,6 +275,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
@ -773,7 +776,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new RestExplainDataFrameAnalyticsAction(restController),
new RestGetTrainedModelsAction(restController),
new RestDeleteTrainedModelAction(restController),
new RestGetTrainedModelsStatsAction(restController)
new RestGetTrainedModelsStatsAction(restController),
new RestPutTrainedModelAction(restController)
);
}
@ -844,7 +848,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new ActionHandler<>(InternalInferModelAction.INSTANCE, TransportInternalInferModelAction.class),
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class)
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class),
new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class)
);
}

View File

@ -0,0 +1,187 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import java.io.IOException;
import java.time.Instant;
import java.util.List;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Request, Response> {
private final TrainedModelProvider trainedModelProvider;
private final XPackLicenseState licenseState;
private final NamedXContentRegistry xContentRegistry;
private final Client client;
@Inject
public TransportPutTrainedModelAction(TransportService transportService, ClusterService clusterService,
ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, Client client,
TrainedModelProvider trainedModelProvider, NamedXContentRegistry xContentRegistry) {
super(PutTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, Request::new,
indexNameExpressionResolver);
this.licenseState = licenseState;
this.trainedModelProvider = trainedModelProvider;
this.xContentRegistry = xContentRegistry;
this.client = client;
}
@Override
protected String executor() {
return ThreadPool.Names.SAME;
}
@Override
protected Response read(StreamInput in) throws IOException {
return new Response(in);
}
@Override
protected void masterOperation(Request request, ClusterState state, ActionListener<Response> listener) {
try {
request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
} catch (IOException ex) {
listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
ex,
request.getTrainedModelConfig().getModelId()));
return;
} catch (ElasticsearchException ex) {
listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.",
ex,
request.getTrainedModelConfig().getModelId()));
return;
}
TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.setCreatedBy("api_user")
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
.build();
ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
listener::onFailure
)),
listener::onFailure
);
ActionListener<Void> modelIdTagCheckListener = ActionListener.wrap(
r -> checkTagsAgainstModelIds(request.getTrainedModelConfig().getTags(), tagsModelIdCheckListener),
listener::onFailure
);
checkModelIdAgainstTags(request.getTrainedModelConfig().getModelId(), modelIdTagCheckListener);
}
private void checkModelIdAgainstTags(String modelId, ActionListener<Void> listener) {
QueryBuilder builder = QueryBuilders.constantScoreQuery(
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId)));
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(
response -> {
if (response.getHits().getTotalHits().value > 0) {
listener.onFailure(
ExceptionsHelper.badRequestException(
Messages.getMessage(Messages.INFERENCE_MODEL_ID_AND_TAGS_UNIQUE, modelId)));
return;
}
listener.onResponse(null);
},
listener::onFailure
),
client::search);
}
/**
 * Asserts that none of the supplied tags equals an existing model id.
 * An empty tag list trivially passes. Otherwise a minimal existence query
 * (size 0, total hits capped at 1) is run against the model_id field; any
 * match fails the listener with a bad-request error, otherwise null is returned.
 */
private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> listener) {
    // Nothing to check when no tags were supplied.
    if (tags.isEmpty()) {
        listener.onResponse(null);
        return;
    }
    QueryBuilder idClashQuery = QueryBuilders.constantScoreQuery(
        QueryBuilders.boolQuery()
            .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags)));
    // Existence check only; skip hits and cap total-hit tracking at 1.
    SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN)
        .source(new SearchSourceBuilder().query(idClashQuery).size(0).trackTotalHitsUpTo(1));
    executeAsyncWithOrigin(client.threadPool().getThreadContext(),
        ML_ORIGIN,
        searchRequest,
        ActionListener.<SearchResponse>wrap(
            response -> {
                if (response.getHits().getTotalHits().value > 0) {
                    listener.onFailure(
                        ExceptionsHelper.badRequestException(
                            Messages.getMessage(Messages.INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE, tags)));
                } else {
                    listener.onResponse(null);
                }
            },
            listener::onFailure
        ),
        client::search);
}
@Override
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
    // Storing a trained model writes cluster metadata, so honor any global metadata-write block.
    final ClusterBlockException blockException =
        state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    return blockException;
}
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
    // Gate the whole action on the ML license: fail fast with a compliance
    // exception when machine learning is not permitted by the current license.
    if (licenseState.isMachineLearningAllowed() == false) {
        listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
        return;
    }
    super.doExecute(task, request, listener);
}
}

View File

@ -175,10 +175,12 @@ public class TrainedModelProvider {
r -> {
assert r.getItems().length == 2;
if (r.getItems()[0].isFailed()) {
logger.error(new ParameterizedMessage(
"[{}] failed to store trained model config for inference",
trainedModelConfig.getModelId()),
r.getItems()[0].getFailure().getCause());
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
return;
}

View File

@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
/**
 * REST handler for storing a trained model: {@code PUT _ml/inference/<model_id>}.
 * Parses the body into a {@link PutTrainedModelAction.Request} and dispatches it.
 */
public class RestPutTrainedModelAction extends BaseRestHandler {

    public RestPutTrainedModelAction(RestController controller) {
        // Register PUT on _ml/inference/{model_id}.
        controller.registerHandler(RestRequest.Method.PUT,
            MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}",
            this);
    }

    @Override
    public String getName() {
        return "xpack_ml_put_trained_model_action";
    }

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        // Model id comes from the URL path; the model configuration comes from the request body.
        final String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
        final XContentParser bodyParser = restRequest.contentParser();
        final PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(modelId, bodyParser);
        // Callers may override the default request timeout via the "timeout" query param.
        putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));
        return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
    }
}

View File

@ -13,7 +13,6 @@ import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.ClusterState;
@ -37,26 +36,30 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
import org.elasticsearch.xpack.core.ml.client.MachineLearningClient;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
@ -561,12 +564,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
" \"target_field\": \"regression_value\",\n" +
" \"model_id\": \"modelprocessorlicensetest\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" \"field_mappings\": {}\n" +
" }\n" +
" }]}\n";
// Creating a pipeline should work
@ -748,76 +746,22 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}
private void putInferenceModel(String modelId) throws Exception {
String config = "" +
"{\n" +
" \"model_id\": \"" + modelId + "\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"license_level\": \"platinum\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0,\n" +
" \"estimated_operations\": 0,\n" +
" \"created_time\": 0\n" +
"}";
String definition = "" +
"{" +
" \"trained_model\": {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }" +
"}";
String compressedDefinitionString =
InferenceToXContentCompressor.deflate(new BytesArray(definition.getBytes(StandardCharsets.UTF_8)));
String compressedDefinition = "" +
"{" +
" \"model_id\": \"" + modelId + "\",\n" +
" \"doc_type\": \"" + TrainedModelDefinitionDoc.NAME + "\",\n" +
" \"doc_num\": " + 0 + ",\n" +
" \"compression_version\": " + 1 + ",\n" +
" \"total_definition_length\": " + compressedDefinitionString.length() + ",\n" +
" \"definition_length\": " + compressedDefinitionString.length() + ",\n" +
" \"definition\": \"" + compressedDefinitionString + "\"\n" +
"}";
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(modelId)
.setSource(config, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(TrainedModelDefinitionDoc.docId(modelId, 0))
.setSource(compressedDefinition, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
// Test helper: stores a minimal trained model under the given id via the PUT
// trained model action, blocking until the call completes (or throws, e.g. on
// a license failure).
private void putInferenceModel(String modelId) {
// Smallest valid model: a single-leaf regression tree over one feature.
// NOTE(review): the description string says "classification" but the tree's
// target type is REGRESSION — confirm whether the text is intentional.
TrainedModelConfig config = TrainedModelConfig.builder()
.setParsedDefinition(
new TrainedModelDefinition.Builder()
.setTrainedModel(
Tree.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(Arrays.asList("feature1"))
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
.build())
.setPreProcessors(Collections.emptyList()))
.setModelId(modelId)
.setDescription("test model for classification")
.setInput(new TrainedModelInput(Arrays.asList("feature1")))
.build();
// Synchronous execution: actionGet() rethrows any failure to the caller.
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
}
private static OperationMode randomInvalidLicenseType() {

View File

@ -200,7 +200,6 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setDescription("trained model config for test")
.setModelId(modelId)
.setVersion(Version.CURRENT)

View File

@ -0,0 +1,28 @@
{
"ml.put_trained_model":{
"documentation":{
"url":"TODO"
},
"stability":"experimental",
"url":{
"paths":[
{
"path":"/_ml/inference/{model_id}",
"methods":[
"PUT"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained model to store"
}
}
}
]
},
"body": {
"description":"The trained model configuration",
"required":true
}
}
}

View File

@ -1,3 +1,74 @@
setup:
- skip:
features: headers
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-regression-model-0
body: >
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "leaf_value": 1}
],
"target_type": "regression"
}
}
}
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-regression-model-1
body: >
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "leaf_value": 1}
],
"target_type": "regression"
}
}
}
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-classification-model
body: >
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "leaf_value": 1}
],
"target_type": "classification",
"classification_labels": ["no", "yes"]
}
}
}
}
---
"Test get given missing trained model":
@ -24,56 +95,52 @@
- match: { count: 0 }
- match: { trained_model_configs: [] }
---
"Test get models":
- do:
ml.get_trained_models:
model_id: "*"
- match: { count: 4 }
- match: { trained_model_configs.0.model_id: "a-classification-model" }
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
- match: { trained_model_configs.2.model_id: "a-regression-model-1" }
- do:
ml.get_trained_models:
model_id: "a-regression*"
- match: { count: 2 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
- do:
ml.get_trained_models:
model_id: "*"
from: 0
size: 2
- match: { count: 4 }
- match: { trained_model_configs.0.model_id: "a-classification-model" }
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
- do:
ml.get_trained_models:
model_id: "*"
from: 1
size: 1
- match: { count: 4 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
---
"Test delete given unused trained model":
- do:
index:
id: trained_model_config-unused-regression-model-0
index: .ml-inference-000001
body: >
{
"model_id": "unused-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
"model_version": 0,
"model_type": "local"
}
- do:
indices.refresh: {}
- do:
ml.delete_trained_model:
model_id: "unused-regression-model"
model_id: "a-classification-model"
- match: { acknowledged: true }
---
"Test delete with missing model":
- do:
catch: missing
ml.delete_trained_model:
model_id: "missing-trained-model"
---
"Test delete given used trained model":
- do:
index:
id: trained_model_config-used-regression-model-0
index: .ml-inference-000001
body: >
{
"model_id": "used-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
"model_version": 0,
"model_type": "local"
}
- do:
indices.refresh: {}
- do:
ingest.put_pipeline:
id: "regression-model-pipeline"
@ -82,7 +149,7 @@
"processors": [
{
"inference" : {
"model_id" : "used-regression-model",
"model_id" : "a-regression-model-0",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
@ -95,12 +162,12 @@
- do:
catch: conflict
ml.delete_trained_model:
model_id: "used-regression-model"
model_id: "a-regression-model-0"
---
"Test get pre-packaged trained models":
- do:
ml.get_trained_models:
model_id: "_all"
model_id: "lang_ident_model_1"
allow_no_match: false
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }