diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
index 4967d8091c9..2e077f547e3 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
@@ -73,6 +73,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@@ -792,6 +793,16 @@ final class MLRequestConverters {
return new Request(HttpDelete.METHOD_NAME, endpoint);
}
+ static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) throws IOException {
+ String endpoint = new EndpointBuilder()
+ .addPathPartAsIs("_ml", "inference")
+ .addPathPart(putTrainedModelRequest.getTrainedModelConfig().getModelId())
+ .build();
+ Request request = new Request(HttpPut.METHOD_NAME, endpoint);
+ request.setEntity(createEntity(putTrainedModelRequest, REQUEST_BODY_CONTENT_TYPE));
+ return request;
+ }
+
static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml")
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
index 0a71b8ddb01..bdb2f22f3b3 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
@@ -100,6 +100,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -2340,6 +2342,48 @@ public final class MachineLearningClient {
Collections.emptySet());
}
+ /**
+ * Put trained model config
+ *
+ * For additional info
+ * see
+ * PUT Trained Model Config documentation
+ *
+ * @param request The {@link PutTrainedModelRequest}
+ * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+ * @return {@link PutTrainedModelResponse} response object
+ */
+ public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, RequestOptions options) throws IOException {
+ return restHighLevelClient.performRequestAndParseEntity(request,
+ MLRequestConverters::putTrainedModel,
+ options,
+ PutTrainedModelResponse::fromXContent,
+ Collections.emptySet());
+ }
+
+ /**
+ * Put trained model config asynchronously and notifies listener upon completion
+ *
+ * For additional info
+ * see
+ * PUT Trained Model Config documentation
+ *
+ * @param request The {@link PutTrainedModelRequest}
+ * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+ * @param listener Listener to be notified upon request completion
+ * @return cancellable that may be used to cancel the request
+ */
+ public Cancellable putTrainedModelAsync(PutTrainedModelRequest request,
+ RequestOptions options,
+ ActionListener listener) {
+ return restHighLevelClient.performRequestAsyncAndParseEntity(request,
+ MLRequestConverters::putTrainedModel,
+ options,
+ PutTrainedModelResponse::fromXContent,
+ listener,
+ Collections.emptySet());
+ }
+
/**
* Gets trained model stats
*
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java
new file mode 100644
index 00000000000..780ec31771b
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelRequest implements Validatable, ToXContentObject {
+
+ private final TrainedModelConfig config;
+
+ public PutTrainedModelRequest(TrainedModelConfig config) {
+ this.config = config;
+ }
+
+ public TrainedModelConfig getTrainedModelConfig() {
+ return config;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+ return config.toXContent(builder, params);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ PutTrainedModelRequest request = (PutTrainedModelRequest) o;
+ return Objects.equals(config, request.config);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(config);
+ }
+
+ @Override
+ public final String toString() {
+ return Strings.toString(config);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java
new file mode 100644
index 00000000000..3bc81f18129
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelResponse implements ToXContentObject {
+
+ private final TrainedModelConfig trainedModelConfig;
+
+ public static PutTrainedModelResponse fromXContent(XContentParser parser) throws IOException {
+ return new PutTrainedModelResponse(TrainedModelConfig.PARSER.parse(parser, null).build());
+ }
+
+ public PutTrainedModelResponse(TrainedModelConfig trainedModelConfig) {
+ this.trainedModelConfig = trainedModelConfig;
+ }
+
+ public TrainedModelConfig getResponse() {
+ return trainedModelConfig;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ return trainedModelConfig.toXContent(builder, params);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ PutTrainedModelResponse response = (PutTrainedModelResponse) o;
+ return Objects.equals(trainedModelConfig, response.trainedModelConfig);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(trainedModelConfig);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java
new file mode 100644
index 00000000000..9bec4c4eb5d
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.CheckedFunction;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.Streams;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
+/**
+ * Collection of helper methods. Similar to CompressedXContent, but this utilizes GZIP.
+ */
+public final class InferenceToXContentCompressor {
+ private static final int BUFFER_SIZE = 4096;
+ private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
+
+ private InferenceToXContentCompressor() {}
+
+ public static String deflate(T objectToCompress) throws IOException {
+ BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
+ return deflate(reference);
+ }
+
+ public static T inflate(String compressedString,
+ CheckedFunction parserFunction,
+ NamedXContentRegistry xContentRegistry) throws IOException {
+ try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+ inflate(compressedString, MAX_INFLATED_BYTES),
+ XContentType.JSON)) {
+ return parserFunction.apply(parser);
+ }
+ }
+
+ static BytesReference inflate(String compressedString, long streamSize) throws IOException {
+ byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
+ InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
+ InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
+ return Streams.readFully(inflateStream);
+ }
+
+ private static String deflate(BytesReference reference) throws IOException {
+ BytesStreamOutput out = new BytesStreamOutput();
+ try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
+ reference.writeTo(compressedOutput);
+ }
+ return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java
new file mode 100644
index 00000000000..683e23dc9d7
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Objects;
+
+/**
+ * This is a pared down bounded input stream.
+ * Only read is specifically enforced.
+ */
+final class SimpleBoundedInputStream extends InputStream {
+
+ private final InputStream in;
+ private final long maxBytes;
+ private long numBytes;
+
+ SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
+ this.in = Objects.requireNonNull(inputStream, "inputStream");
+ if (maxBytes < 0) {
+ throw new IllegalArgumentException("[maxBytes] must be greater than or equal to 0");
+ }
+ this.maxBytes = maxBytes;
+ }
+
+
+ /**
+ * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
+ * @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
+ * @throws IOException on failure
+ */
+ @Override
+ public int read() throws IOException {
+ // We have reached the maximum, signal stream completion.
+ if (numBytes >= maxBytes) {
+ return -1;
+ }
+ numBytes++;
+ return in.read();
+ }
+
+ /**
+ * Delegates `close` to the wrapped InputStream
+ * @throws IOException on failure
+ */
+ @Override
+ public void close() throws IOException {
+ in.close();
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
index 23eb01fb3b1..9d2b323cf48 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
@@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.time.Instant;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -111,7 +112,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
- this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli());
+ this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
this.definition = definition;
this.compressedDefinition = compressedDefinition;
this.description = description;
@@ -293,12 +294,12 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
- public Builder setCreatedBy(String createdBy) {
+ private Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
}
- public Builder setVersion(Version version) {
+ private Builder setVersion(Version version) {
this.version = version;
return this;
}
@@ -312,7 +313,7 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
- public Builder setCreateTime(Instant createTime) {
+ private Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
}
@@ -322,6 +323,10 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
+ public Builder setTags(String... tags) {
+ return setTags(Arrays.asList(tags));
+ }
+
public Builder setMetadata(Map metadata) {
this.metadata = metadata;
return this;
@@ -347,17 +352,17 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
- public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
+ private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
- public Builder setEstimatedOperations(Long estimatedOperations) {
+ private Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
- public Builder setLicenseLevel(String licenseLevel) {
+ private Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java
index 10f849cac48..9b19323023d 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java
@@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
+import java.util.Arrays;
import java.util.List;
import java.util.Objects;
@@ -48,6 +49,10 @@ public class TrainedModelInput implements ToXContentObject {
this.fieldNames = fieldNames;
}
+ public TrainedModelInput(String... fieldNames) {
+ this(Arrays.asList(fieldNames));
+ }
+
public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
index 24825dfc265..f5733ef3a0d 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
@@ -71,6 +71,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@@ -91,6 +92,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.Detector;
import org.elasticsearch.client.ml.job.config.Job;
@@ -874,6 +878,20 @@ public class MLRequestConvertersTests extends ESTestCase {
assertNull(request.getEntity());
}
+ public void testPutTrainedModel() throws IOException {
+ TrainedModelConfig trainedModelConfig = TrainedModelConfigTests.createTestTrainedModelConfig();
+ PutTrainedModelRequest putTrainedModelRequest = new PutTrainedModelRequest(trainedModelConfig);
+
+ Request request = MLRequestConverters.putTrainedModel(putTrainedModelRequest);
+
+ assertEquals(HttpPut.METHOD_NAME, request.getMethod());
+ assertThat(request.getEndpoint(), equalTo("/_ml/inference/" + trainedModelConfig.getModelId()));
+ try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) {
+ TrainedModelConfig parsedTrainedModelConfig = TrainedModelConfig.PARSER.apply(parser, null).build();
+ assertThat(parsedTrainedModelConfig, equalTo(trainedModelConfig));
+ }
+ }
+
public void testPutFilter() throws IOException {
MlFilter filter = MlFilterTests.createRandomBuilder("foo").build();
PutFilterRequest putFilterRequest = new PutFilterRequest(filter);
@@ -1046,6 +1064,7 @@ public class MLRequestConvertersTests extends ESTestCase {
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+ namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
index 170fd4e858a..247b726e008 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
@@ -101,6 +101,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -146,9 +148,12 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
@@ -162,14 +167,12 @@ import org.elasticsearch.client.ml.job.config.MlFilter;
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@@ -178,11 +181,9 @@ import org.elasticsearch.search.SearchHit;
import org.junit.After;
import java.io.IOException;
-import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -190,7 +191,6 @@ import java.util.Locale;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
-import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
@@ -2222,6 +2222,50 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
}
}
+ public void testPutTrainedModel() throws Exception {
+ String modelId = "test-put-trained-model";
+ String modelIdCompressed = "test-put-trained-model-compressed-definition";
+
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId(modelId)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ PutTrainedModelResponse putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
+ machineLearningClient::putTrainedModel,
+ machineLearningClient::putTrainedModelAsync);
+ TrainedModelConfig createdModel = putTrainedModelResponse.getResponse();
+ assertThat(createdModel.getModelId(), equalTo(modelId));
+
+ definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ trainedModelConfig = TrainedModelConfig.builder()
+ .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition))
+ .setModelId(modelIdCompressed)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
+ machineLearningClient::putTrainedModel,
+ machineLearningClient::putTrainedModelAsync);
+ createdModel = putTrainedModelResponse.getResponse();
+ assertThat(createdModel.getModelId(), equalTo(modelIdCompressed));
+
+ GetTrainedModelsResponse getTrainedModelsResponse = execute(
+ new GetTrainedModelsRequest(modelIdCompressed).setDecompressDefinition(true).setIncludeDefinition(true),
+ machineLearningClient::getTrainedModels,
+ machineLearningClient::getTrainedModelsAsync);
+
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue()));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue())));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdCompressed));
+ }
+
public void testGetTrainedModelsStats() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String modelIdPrefix = "a-get-trained-model-stats-";
@@ -2504,56 +2548,13 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
private void putTrainedModel(String modelId) throws IOException {
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelConfigString(modelId), XContentType.JSON)
- .id(modelId),
- RequestOptions.DEFAULT);
-
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
- .id("trained_model_definition_doc-" + modelId + "-0"),
- RequestOptions.DEFAULT);
- }
-
- private String compressDefinition(TrainedModelDefinition definition) throws IOException {
- BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
- BytesStreamOutput out = new BytesStreamOutput();
- try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
- reference.writeTo(compressedOutput);
- }
- return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
- }
-
- private static String modelConfigString(String modelId) {
- return "{\n" +
- " \"doc_type\": \"trained_model_config\",\n" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
- " \"description\": \"test model\",\n" +
- " \"version\": \"7.6.0\",\n" +
- " \"license_level\": \"platinum\",\n" +
- " \"created_by\": \"ml_test\",\n" +
- " \"estimated_heap_memory_usage_bytes\": 0," +
- " \"estimated_operations\": 0," +
- " \"created_time\": 0\n" +
- "}";
- }
-
- private static String modelDocString(String compressedDefinition, String modelId) {
- return "" +
- "{" +
- "\"model_id\": \"" + modelId + "\",\n" +
- "\"doc_num\": 0,\n" +
- "\"doc_type\": \"trained_model_definition_doc\",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
- " \"definition_length\": " + compressedDefinition.length() + ",\n" +
- "\"definition\": \"" + compressedDefinition + "\"\n" +
- "}";
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId(modelId)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
private void waitForJobToClose(String jobId) throws Exception {
@@ -2798,4 +2799,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
mlInfoResponse = machineLearningClient.getMlInfo(new MlInfoRequest(), RequestOptions.DEFAULT);
assertThat(mlInfoResponse.getInfo().get("upgrade_mode"), equalTo(false));
}
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
index 94863cbd5e6..860fe533fd3 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
@@ -114,6 +114,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -162,10 +164,14 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
import org.elasticsearch.client.ml.job.config.DataDescription;
@@ -186,12 +192,11 @@ import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@@ -202,12 +207,10 @@ import org.elasticsearch.tasks.TaskId;
import org.junit.After;
import java.io.IOException;
-import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
-import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
@@ -216,7 +219,6 @@ import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
-import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
@@ -3625,6 +3627,79 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
}
}
+ public void testPutTrainedModel() throws Exception {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ // tag::put-trained-model-config
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition) // <1>
+ .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2>
+ .setModelId("my-new-trained-model") // <3>
+ .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4>
+ .setDescription("test model") // <5>
+ .setMetadata(new HashMap<>()) // <6>
+ .setTags("my_regression_models") // <7>
+ .build();
+ // end::put-trained-model-config
+
+ trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId("my-new-trained-model")
+ .setInput(new TrainedModelInput("col1", "col2", "col3", "col4"))
+ .setDescription("test model")
+ .setMetadata(new HashMap<>())
+ .setTags("my_regression_models")
+ .build();
+
+ RestHighLevelClient client = highLevelClient();
+ {
+ // tag::put-trained-model-request
+ PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); // <1>
+ // end::put-trained-model-request
+
+ // tag::put-trained-model-execute
+ PutTrainedModelResponse response = client.machineLearning().putTrainedModel(request, RequestOptions.DEFAULT);
+ // end::put-trained-model-execute
+
+ // tag::put-trained-model-response
+ TrainedModelConfig model = response.getResponse();
+ // end::put-trained-model-response
+
+ assertThat(model.getModelId(), equalTo(trainedModelConfig.getModelId()));
+ highLevelClient().machineLearning()
+ .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
+ }
+ {
+ PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
+
+ // tag::put-trained-model-execute-listener
+ ActionListener listener = new ActionListener() {
+ @Override
+ public void onResponse(PutTrainedModelResponse response) {
+ // <1>
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ // <2>
+ }
+ };
+ // end::put-trained-model-execute-listener
+
+ // Replace the empty listener by a blocking listener in test
+ CountDownLatch latch = new CountDownLatch(1);
+ listener = new LatchedActionListener<>(listener, latch);
+
+ // tag::put-trained-model-execute-async
+ client.machineLearning().putTrainedModelAsync(request, RequestOptions.DEFAULT, listener); // <1>
+ // end::put-trained-model-execute-async
+
+ assertTrue(latch.await(30L, TimeUnit.SECONDS));
+
+ highLevelClient().machineLearning()
+ .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
+ }
+ }
+
public void testGetTrainedModelsStats() throws Exception {
putTrainedModel("my-trained-model");
RestHighLevelClient client = highLevelClient();
@@ -4088,57 +4163,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
}
private void putTrainedModel(String modelId) throws IOException {
- TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelConfigString(modelId), XContentType.JSON)
- .id(modelId),
- RequestOptions.DEFAULT);
-
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
- .id("trained_model_definition_doc-" + modelId + "-0"),
- RequestOptions.DEFAULT);
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId(modelId)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
- private String compressDefinition(TrainedModelDefinition definition) throws IOException {
- BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
- BytesStreamOutput out = new BytesStreamOutput();
- try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
- reference.writeTo(compressedOutput);
- }
- return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
- }
-
- private static String modelConfigString(String modelId) {
- return "{\n" +
- " \"doc_type\": \"trained_model_config\",\n" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
- " \"description\": \"test model for\",\n" +
- " \"version\": \"7.6.0\",\n" +
- " \"license_level\": \"platinum\",\n" +
- " \"created_by\": \"ml_test\",\n" +
- " \"estimated_heap_memory_usage_bytes\": 0," +
- " \"estimated_operations\": 0," +
- " \"created_time\": 0\n" +
- "}";
- }
-
- private static String modelDocString(String compressedDefinition, String modelId) {
- return "" +
- "{" +
- "\"model_id\": \"" + modelId + "\",\n" +
- "\"doc_num\": 0,\n" +
- "\"doc_type\": \"trained_model_definition_doc\",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
- " \"definition_length\": " + compressedDefinition.length() + ",\n" +
- "\"definition\": \"" + compressedDefinition + "\"\n" +
- "}";
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG =
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java
new file mode 100644
index 00000000000..b3956c5c6af
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class PutTrainedModelActionRequestTests extends AbstractXContentTestCase {
+
+ @Override
+ protected PutTrainedModelRequest createTestInstance() {
+ return new PutTrainedModelRequest(TrainedModelConfigTests.createTestTrainedModelConfig());
+ }
+
+ @Override
+ protected PutTrainedModelRequest doParseInstance(XContentParser parser) throws IOException {
+ return new PutTrainedModelRequest(TrainedModelConfig.PARSER.apply(parser, null).build());
+ }
+
+ @Override
+ protected boolean supportsUnknownFields() {
+ return false;
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java
new file mode 100644
index 00000000000..61e1638547b
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class PutTrainedModelActionResponseTests extends AbstractXContentTestCase {
+
+ @Override
+ protected PutTrainedModelResponse createTestInstance() {
+ return new PutTrainedModelResponse(TrainedModelConfigTests.createTestTrainedModelConfig());
+ }
+
+ @Override
+ protected PutTrainedModelResponse doParseInstance(XContentParser parser) throws IOException {
+ return new PutTrainedModelResponse(TrainedModelConfig.PARSER.apply(parser, null).build());
+ }
+
+ @Override
+ protected boolean supportsUnknownFields() {
+ return false;
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java
new file mode 100644
index 00000000000..11747638a2c
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class InferenceToXContentCompressorTests extends ESTestCase {
+
+ public void testInflateAndDeflate() throws IOException {
+ for(int i = 0; i < 10; i++) {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+ String firstDeflate = InferenceToXContentCompressor.deflate(definition);
+ TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate,
+ parser -> TrainedModelDefinition.fromXContent(parser).build(),
+ xContentRegistry());
+
+ // Did we inflate to the same object?
+ assertThat(inflatedDefinition, equalTo(definition));
+ }
+ }
+
+ public void testInflateTooLargeStream() throws IOException {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+ String firstDeflate = InferenceToXContentCompressor.deflate(definition);
+ BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
+ assertThat(inflatedBytes.length(), equalTo(10));
+ try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+ LoggingDeprecationHandler.INSTANCE,
+ inflatedBytes,
+ XContentType.JSON)) {
+ expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser));
+ }
+ }
+
+ public void testInflateGarbage() {
+ expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
index 51f9692b8b6..81c64b3a22c 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
@@ -37,6 +37,24 @@ import java.util.stream.Stream;
public class TrainedModelConfigTests extends AbstractXContentTestCase {
+ public static TrainedModelConfig createTestTrainedModelConfig() {
+ return new TrainedModelConfig(
+ randomAlphaOfLength(10),
+ randomAlphaOfLength(10),
+ Version.CURRENT,
+ randomBoolean() ? null : randomAlphaOfLength(100),
+ Instant.ofEpochMilli(randomNonNegativeLong()),
+ randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
+ randomBoolean() ? null : randomAlphaOfLength(100),
+ randomBoolean() ? null :
+ Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
+ randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
+ randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
+ randomBoolean() ? null : randomNonNegativeLong(),
+ randomBoolean() ? null : randomNonNegativeLong(),
+ randomBoolean() ? null : randomFrom("platinum", "basic"));
+ }
+
@Override
protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException {
return TrainedModelConfig.fromXContent(parser);
@@ -54,22 +72,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
- randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
- randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
- randomBoolean() ? null : randomNonNegativeLong(),
- randomBoolean() ? null : randomNonNegativeLong(),
- randomBoolean() ? null : randomFrom("platinum", "basic"));
-
+ return createTestTrainedModelConfig();
}
@Override
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
index 72a640d609b..5c857bb625b 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
@@ -67,15 +67,17 @@ public class EnsembleTests extends AbstractXContentTestCase {
.collect(Collectors.toList());
int numberOfModels = randomIntBetween(1, 10);
List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
- .limit(numberOfFeatures)
+ .limit(numberOfModels)
.collect(Collectors.toList());
- OutputAggregator outputAggregator = null;
- if (randomBoolean()) {
- List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
- outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights));
+ List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
+ List possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
+ new LogisticRegression(weights)));
+ if (targetType.equals(TargetType.REGRESSION)) {
+ possibleAggregators.add(new WeightedSum(weights));
}
+ OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
List categoryLabels = null;
- if (randomBoolean()) {
+ if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return new Ensemble(featureNames,
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
index 748dc982e67..5e105dffd8b 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
@@ -84,7 +84,7 @@ public class TreeTests extends AbstractXContentTestCase {
childNodes = nextNodes;
}
List categoryLabels = null;
- if (randomBoolean()) {
+ if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return builder.setClassificationLabels(categoryLabels)
diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc
new file mode 100644
index 00000000000..dadc8dcf65a
--- /dev/null
+++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc
@@ -0,0 +1,53 @@
+--
+:api: put-trained-model
+:request: PutTrainedModelRequest
+:response: PutTrainedModelResponse
+--
+[role="xpack"]
+[id="{upid}-{api}"]
+=== Put Trained Model API
+
+Creates a new trained model for inference.
+The API accepts a +{request}+ object as a request and returns a +{response}+.
+
+[id="{upid}-{api}-request"]
+==== Put Trained Model request
+
+A +{request}+ requires the following argument:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-request]
+--------------------------------------------------
+<1> The configuration of the {infer} Trained Model to create
+
+[id="{upid}-{api}-config"]
+==== Trained Model configuration
+
+The `TrainedModelConfig` object contains all the details about the trained model
+configuration and contains the following arguments:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-config]
+--------------------------------------------------
+<1> The {infer} definition for the model
+<2> Optionally, if the {infer} definition is large, you may choose to compress it for transport.
+ Do not supply both the compressed and uncompressed definitions.
+<3> The unique model id
+<4> The input field names for the model definition
+<5> Optionally, a human-readable description
+<6> Optionally, an object map contain metadata about the model
+<7> Optionally, an array of tags to organize the model
+
+include::../execution.asciidoc[]
+
+[id="{upid}-{api}-response"]
+==== Response
+
+The returned +{response}+ contains the newly created trained model.
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-response]
+--------------------------------------------------
diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc
index e0d228b5d1e..4b848819702 100644
--- a/docs/java-rest/high-level/supported-apis.asciidoc
+++ b/docs/java-rest/high-level/supported-apis.asciidoc
@@ -304,6 +304,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
* <<{upid}-evaluate-data-frame>>
* <<{upid}-explain-data-frame-analytics>>
* <<{upid}-get-trained-models>>
+* <<{upid}-put-trained-model>>
* <<{upid}-get-trained-models-stats>>
* <<{upid}-delete-trained-model>>
* <<{upid}-put-filter>>
@@ -359,6 +360,7 @@ include::ml/stop-data-frame-analytics.asciidoc[]
include::ml/evaluate-data-frame.asciidoc[]
include::ml/explain-data-frame-analytics.asciidoc[]
include::ml/get-trained-models.asciidoc[]
+include::ml/put-trained-model.asciidoc[]
include::ml/get-trained-models-stats.asciidoc[]
include::ml/delete-trained-model.asciidoc[]
include::ml/put-filter.asciidoc[]
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
index ca0227e84ae..843e3e611df 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
@@ -126,6 +126,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
@@ -381,6 +382,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,
+ PutTrainedModelAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
new file mode 100644
index 00000000000..06fbb6401a0
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelAction extends ActionType {
+
+ public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
+ public static final String NAME = "cluster:monitor/xpack/ml/inference/put";
+ private PutTrainedModelAction() {
+ super(NAME, Response::new);
+ }
+
+ public static class Request extends AcknowledgedRequest {
+
+ public static Request parseRequest(String modelId, XContentParser parser) {
+ TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);
+
+ if (builder.getModelId() == null) {
+ builder.setModelId(modelId).build();
+ } else if (!Strings.isNullOrEmpty(modelId) && !modelId.equals(builder.getModelId())) {
+ // If we have model_id in both URI and body, they must be identical
+ throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID,
+ TrainedModelConfig.MODEL_ID.getPreferredName(),
+ builder.getModelId(),
+ modelId));
+ }
+ // Validations are done against the builder so we can build the full config object.
+ // This allows us to not worry about serializing a builder class between nodes.
+ return new Request(builder.validate(true).build());
+ }
+
+ private final TrainedModelConfig config;
+
+ public Request(TrainedModelConfig config) {
+ this.config = config;
+ }
+
+ public Request(StreamInput in) throws IOException {
+ super(in);
+ this.config = new TrainedModelConfig(in);
+ }
+
+ public TrainedModelConfig getTrainedModelConfig() {
+ return config;
+ }
+
+ @Override
+ public ActionRequestValidationException validate() {
+ return null;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ config.writeTo(out);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Request request = (Request) o;
+ return Objects.equals(config, request.config);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(config);
+ }
+
+ @Override
+ public final String toString() {
+ return Strings.toString(config);
+ }
+ }
+
+ public static class Response extends ActionResponse implements ToXContentObject {
+
+ private final TrainedModelConfig trainedModelConfig;
+
+ public Response(TrainedModelConfig trainedModelConfig) {
+ this.trainedModelConfig = trainedModelConfig;
+ }
+
+ public Response(StreamInput in) throws IOException {
+ super(in);
+ trainedModelConfig = new TrainedModelConfig(in);
+ }
+
+ public TrainedModelConfig getResponse() {
+ return trainedModelConfig;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ trainedModelConfig.writeTo(out);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ return trainedModelConfig.toXContent(builder, params);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Response response = (Response) o;
+ return Objects.equals(trainedModelConfig, response.trainedModelConfig);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(trainedModelConfig);
+ }
+ }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
index be4d40efc85..95589ac8b61 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
@@ -34,6 +35,9 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.action.ValidateActions.addValidationError;
public class TrainedModelConfig implements ToXContentObject, Writeable {
@@ -352,13 +356,31 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private Long estimatedHeapMemory;
private Long estimatedOperations;
private LazyModelDefinition definition;
- private String licenseLevel = License.OperationMode.PLATINUM.description();
+ private String licenseLevel;
+
+ public Builder() {}
+
+ public Builder(TrainedModelConfig config) {
+ this.modelId = config.getModelId();
+ this.createdBy = config.getCreatedBy();
+ this.version = config.getVersion();
+ this.createTime = config.getCreateTime();
+ this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
+ this.description = config.getDescription();
+ this.tags = config.getTags();
+ this.metadata = config.getMetadata();
+ this.input = config.getInput();
+ }
public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}
+ public String getModelId() {
+ return this.modelId;
+ }
+
public Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
@@ -466,51 +488,96 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
- // TODO move to REST level instead of here in the builder
- public void validate() {
- // We require a definition to be available here even though it will be stored in a different doc
- ExceptionsHelper.requireNonNull(definition, DEFINITION);
- ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
-
- if (MlStrings.isValidId(modelId) == false) {
- throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId));
- }
-
- if (MlStrings.hasValidLengthForId(modelId) == false) {
- throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG,
- MODEL_ID.getPreferredName(),
- modelId,
- MlStrings.ID_LENGTH_LIMIT));
- }
-
- checkIllegalSetting(version, VERSION.getPreferredName());
- checkIllegalSetting(createdBy, CREATED_BY.getPreferredName());
- checkIllegalSetting(createTime, CREATE_TIME.getPreferredName());
- checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
- checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName());
- checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName());
+ public Builder validate() {
+ return validate(false);
}
- private static void checkIllegalSetting(Object value, String setting) {
- if (value != null) {
- throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting);
+ /**
+ * Runs validations against the builder.
+ * @return The current builder object if validations are successful
+ * @throws ActionRequestValidationException when there are validation failures.
+ */
+ public Builder validate(boolean forCreation) {
+ // We require a definition to be available here even though it will be stored in a different doc
+ ActionRequestValidationException validationException = null;
+ if (definition == null) {
+ validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] must not be null.", validationException);
}
+ if (modelId == null) {
+ validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException);
+ }
+
+ if (modelId != null && MlStrings.isValidId(modelId) == false) {
+ validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID,
+ TrainedModelConfig.MODEL_ID.getPreferredName(),
+ modelId),
+ validationException);
+ }
+ if (modelId != null && MlStrings.hasValidLengthForId(modelId) == false) {
+ validationException = addValidationError(Messages.getMessage(Messages.ID_TOO_LONG,
+ TrainedModelConfig.MODEL_ID.getPreferredName(),
+ modelId,
+ MlStrings.ID_LENGTH_LIMIT), validationException);
+ }
+ List badTags = tags.stream()
+ .filter(tag -> (MlStrings.isValidId(tag) && MlStrings.hasValidLengthForId(tag)) == false)
+ .collect(Collectors.toList());
+ if (badTags.isEmpty() == false) {
+ validationException = addValidationError(Messages.getMessage(Messages.INFERENCE_INVALID_TAGS,
+ badTags,
+ MlStrings.ID_LENGTH_LIMIT),
+ validationException);
+ }
+
+ for(String tag : tags) {
+ if (tag.equals(modelId)) {
+ validationException = addValidationError("none of the tags must equal the model_id", validationException);
+ break;
+ }
+ }
+ if (forCreation) {
+ validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(estimatedHeapMemory,
+ ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
+ validationException);
+ validationException = checkIllegalSetting(estimatedOperations,
+ ESTIMATED_OPERATIONS.getPreferredName(),
+ validationException);
+ validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
+ }
+
+ if (validationException != null) {
+ throw validationException;
+ }
+
+ return this;
+ }
+
+ private static ActionRequestValidationException checkIllegalSetting(Object value,
+ String setting,
+ ActionRequestValidationException validationException) {
+ if (value != null) {
+ return addValidationError("illegal to set [" + setting + "] at inference model creation", validationException);
+ }
+ return validationException;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
- createdBy,
- version,
+ createdBy == null ? "user" : createdBy,
+ version == null ? Version.CURRENT : version,
description,
createTime == null ? Instant.now() : createTime,
definition,
tags,
metadata,
input,
- estimatedHeapMemory,
- estimatedOperations,
- licenseLevel);
+ estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
+ estimatedOperations == null ? 0 : estimatedOperations,
+ licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
}
}
@@ -531,6 +598,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return new LazyModelDefinition(input.readString(), null);
}
+ private LazyModelDefinition(LazyModelDefinition definition) {
+ if (definition != null) {
+ this.compressedString = definition.compressedString;
+ this.parsedDefinition = definition.parsedDefinition;
+ }
+ }
+
private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) {
if (compressedString == null && trainedModelDefinition == null) {
throw new IllegalArgumentException("unexpected null model definition");
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java
index e176d9a2885..cf7a8b7d224 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java
@@ -179,6 +179,12 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
this(true);
}
+ public Builder(TrainedModelDefinition definition) {
+ this(true);
+ this.preProcessors = new ArrayList<>(definition.getPreProcessors());
+ this.trainedModel = definition.trainedModel;
+ }
+
public Builder setPreProcessors(List preProcessors) {
this.preProcessors = preProcessors;
return this;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
index 1b80d196359..ef0fcd4fdb1 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
@@ -95,6 +95,10 @@ public final class Messages {
public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
+ public static final String INFERENCE_INVALID_TAGS = "Invalid tags {0}; must only can contain lowercase alphanumeric (a-z and 0-9), " +
+ "hyphens or underscores, must start and end with alphanumeric, and must be less than {1} characters.";
+ public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids.";
+ public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags.";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java
new file mode 100644
index 00000000000..0c39469c902
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
+
+public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTestCase {
+
+ @Override
+ protected Request createTestInstance() {
+ String modelId = randomAlphaOfLength(10);
+ return new Request(TrainedModelConfigTests.createTestInstance(modelId)
+ .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+ .build());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return (in) -> {
+ Request request = new Request(in);
+ request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry());
+ return request;
+ };
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+ @Override
+ protected NamedWriteableRegistry getNamedWriteableRegistry() {
+ return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java
new file mode 100644
index 00000000000..5813b13c8ad
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
+
+public class PutTrainedModelActionResponseTests extends AbstractWireSerializingTestCase {
+
+ @Override
+ protected Response createTestInstance() {
+ String modelId = randomAlphaOfLength(10);
+ return new Response(TrainedModelConfigTests.createTestInstance(modelId)
+ .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+ .build());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return (in) -> {
+ Response response = new Response(in);
+ response.getResponse().ensureParsedDefinition(xContentRegistry());
+ return response;
+ };
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+ @Override
+ protected NamedWriteableRegistry getNamedWriteableRegistry() {
+ return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
index 61a3b960f56..81011262012 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
@@ -5,8 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.inference;
-import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
@@ -56,14 +56,16 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase TrainedModelConfig.builder().validate());
- assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
+ () -> TrainedModelConfig.builder().validate());
+ assertThat(ex.getMessage(), containsString("[definition] must not be null."));
}
public void testValidateWithInvalidID() {
String modelId = "InvalidID-";
- ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
+ assertThat(ex.getMessage(), containsString(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
}
public void testValidateWithLongID() {
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
- ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
+ assertThat(ex.getMessage(),
+ containsString(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
}
public void testValidateWithIllegallyUserProvidedFields() {
String modelId = "simplemodel";
- ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreateTime(Instant.now())
- .setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
+ .setModelId(modelId).validate(true));
+ assertThat(ex.getMessage(), containsString("illegal to set [create_time] at inference model creation"));
- ex = expectThrows(ElasticsearchException.class,
+ ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setVersion(Version.CURRENT)
- .setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
+ .setModelId(modelId).validate(true));
+ assertThat(ex.getMessage(), containsString("illegal to set [version] at inference model creation"));
- ex = expectThrows(ElasticsearchException.class,
+ ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreatedBy("ml_user")
- .setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
+ .setModelId(modelId).validate(true));
+ assertThat(ex.getMessage(), containsString("illegal to set [created_by] at inference model creation"));
}
public void testSerializationWithLazyDefinition() throws IOException {
diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle
index 8c9ba6df7f0..7e19a4d606d 100644
--- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle
+++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle
@@ -133,7 +133,6 @@ integTest.runner {
'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id',
'ml/get_datafeeds/Test get datafeed given missing datafeed_id',
'ml/inference_crud/Test delete given used trained model',
- 'ml/inference_crud/Test delete given unused trained model',
'ml/inference_crud/Test delete with missing model',
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
index 20aea0f6316..545ca247a4a 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
@@ -9,16 +9,20 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
-import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
-import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.junit.After;
import org.junit.Before;
import java.io.IOException;
@@ -35,26 +39,14 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
@Before
public void createBothModels() throws Exception {
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
- .setId("test_classification")
- .setSource(CLASSIFICATION_CONFIG, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
- .setId(TrainedModelDefinitionDoc.docId("test_classification", 0))
- .setSource(buildClassificationModelDoc(), XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
- .setId("test_regression")
- .setSource(REGRESSION_CONFIG, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
- .setId(TrainedModelDefinitionDoc.docId("test_regression", 0))
- .setSource(buildRegressionModelDoc(), XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
+ client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
+ client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
+ }
+
+ @After
+ public void deleteBothModels() {
+ client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
+ client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
}
public void testPipelineCreationAndDeletion() throws Exception {
@@ -392,6 +384,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" +
" \"version\": \"7.6.0\",\n" +
+ " \"definition\": " + REGRESSION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
@@ -519,28 +512,27 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }\n" +
"}";
- private static String buildClassificationModelDoc() throws IOException {
- String compressed =
- InferenceToXContentCompressor.deflate(new BytesArray(CLASSIFICATION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
- return modelDocString(compressed, "test_classification");
+ private TrainedModelConfig buildClassificationModel() throws IOException {
+ try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+ new BytesArray(CLASSIFICATION_CONFIG),
+ XContentType.JSON)) {
+ return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
+ }
}
- private static String buildRegressionModelDoc() throws IOException {
- String compressed = InferenceToXContentCompressor.deflate(new BytesArray(REGRESSION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
- return modelDocString(compressed, "test_regression");
+ private TrainedModelConfig buildRegressionModel() throws IOException {
+ try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+ new BytesArray(REGRESSION_CONFIG),
+ XContentType.JSON)) {
+ return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
+ }
}
- private static String modelDocString(String compressedDefinition, String modelId) {
- return "" +
- "{" +
- "\"model_id\": \"" + modelId + "\",\n" +
- "\"doc_num\": 0,\n" +
- "\"doc_type\": \"trained_model_definition_doc\",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
- " \"definition_length\": " + compressedDefinition.length() + ",\n" +
- "\"definition\": \"" + compressedDefinition + "\"\n" +
- "}";
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
private static final String CLASSIFICATION_CONFIG = "" +
@@ -549,8 +541,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
+ " \"definition\": " + CLASSIFICATION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
- " \"created_by\": \"benwtrent\",\n" +
+ " \"created_by\": \"es_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
index 72c677d1aa4..0aec6bc3374 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
@@ -6,10 +6,18 @@
package org.elasticsearch.xpack.ml.integration;
import org.apache.http.util.EntityUtils;
-import org.elasticsearch.Version;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -18,26 +26,19 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
-import org.elasticsearch.license.License;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
-import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
-import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.junit.After;
import java.io.IOException;
-import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
+import java.util.List;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.containsString;
@@ -62,22 +63,8 @@ public class TrainedModelIT extends ESRestTestCase {
public void testGetTrainedModels() throws IOException {
String modelId = "a_test_regression_model";
String modelId2 = "a_test_regression_model-2";
- Request model1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
- model1.setJsonEntity(buildRegressionModel(modelId));
- assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
-
- Request modelDefinition1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
- modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
- assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
-
- Request model2 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
- model2.setJsonEntity(buildRegressionModel(modelId2));
- assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201));
-
- adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
+ putRegressionModel(modelId);
+ putRegressionModel(modelId2);
Response getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/" + modelId));
@@ -164,17 +151,7 @@ public class TrainedModelIT extends ESRestTestCase {
public void testDeleteTrainedModels() throws IOException {
String modelId = "test_delete_regression_model";
- Request model1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
- model1.setJsonEntity(buildRegressionModel(modelId));
- assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
-
- Request modelDefinition1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
- modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
- assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
-
- adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
+ putRegressionModel(modelId);
Response delModel = client().performRequest(new Request("DELETE",
MachineLearning.BASE_PATH + "inference/" + modelId));
@@ -208,42 +185,68 @@ public class TrainedModelIT extends ESRestTestCase {
assertThat(response, containsString("\"definition\""));
}
- private static String buildRegressionModel(String modelId) throws IOException {
+ private void putRegressionModel(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
+ TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
+ .setPreProcessors(Collections.emptyList())
+ .setTrainedModel(buildRegression());
TrainedModelConfig.builder()
+ .setDefinition(definition)
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
- .setCreatedBy("ml_test")
- .setVersion(Version.CURRENT)
- .setCreateTime(Instant.now())
- .setEstimatedOperations(0)
- .setLicenseLevel(License.OperationMode.PLATINUM.description())
- .setEstimatedHeapMemory(0)
- .build()
- .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
- return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
+ .build().toXContent(builder, ToXContent.EMPTY_PARAMS);
+ Request model = new Request("PUT", "_ml/inference/" + modelId);
+ model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
+ assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
}
}
- private static String buildRegressionModelDefinitionDoc(String modelId) throws IOException {
- try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
- TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
- .setPreProcessors(Collections.emptyList())
- .setTrainedModel(LocalModelTests.buildRegression())
- .build();
- String compressedString = InferenceToXContentCompressor.deflate(definition);
- TrainedModelDefinitionDoc doc = new TrainedModelDefinitionDoc.Builder().setDocNum(0)
- .setCompressedString(compressedString)
- .setTotalDefinitionLength(compressedString.length())
- .setDefinitionLength(compressedString.length())
- .setCompressionVersion(1)
- .setModelId(modelId).build();
- doc.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
- return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
- }
+ private static TrainedModel buildRegression() {
+ List featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
+ Tree tree1 = Tree.builder()
+ .setFeatureNames(featureNames)
+ .setNodes(TreeNode.builder(0)
+ .setLeftChild(1)
+ .setRightChild(2)
+ .setSplitFeature(0)
+ .setThreshold(0.5),
+ TreeNode.builder(1).setLeafValue(0.3),
+ TreeNode.builder(2)
+ .setThreshold(0.0)
+ .setSplitFeature(3)
+ .setLeftChild(3)
+ .setRightChild(4),
+ TreeNode.builder(3).setLeafValue(0.1),
+ TreeNode.builder(4).setLeafValue(0.2))
+ .build();
+ Tree tree2 = Tree.builder()
+ .setFeatureNames(featureNames)
+ .setNodes(TreeNode.builder(0)
+ .setLeftChild(1)
+ .setRightChild(2)
+ .setSplitFeature(2)
+ .setThreshold(1.0),
+ TreeNode.builder(1).setLeafValue(1.5),
+ TreeNode.builder(2).setLeafValue(0.9))
+ .build();
+ Tree tree3 = Tree.builder()
+ .setFeatureNames(featureNames)
+ .setNodes(TreeNode.builder(0)
+ .setLeftChild(1)
+ .setRightChild(2)
+ .setSplitFeature(1)
+ .setThreshold(0.2),
+ TreeNode.builder(1).setLeafValue(1.5),
+ TreeNode.builder(2).setLeafValue(0.9))
+ .build();
+ return Ensemble.builder()
+ .setTargetType(TargetType.REGRESSION)
+ .setFeatureNames(featureNames)
+ .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
+ .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5)))
+ .build();
}
-
@After
public void clearMlState() throws Exception {
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 1a12dad3114..1ce8566cf8c 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -112,6 +112,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
@@ -183,6 +184,7 @@ import org.elasticsearch.xpack.ml.action.TransportPutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.ml.action.TransportPutDatafeedAction;
import org.elasticsearch.xpack.ml.action.TransportPutFilterAction;
import org.elasticsearch.xpack.ml.action.TransportPutJobAction;
+import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAction;
import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
@@ -273,6 +275,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
@@ -773,7 +776,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new RestExplainDataFrameAnalyticsAction(restController),
new RestGetTrainedModelsAction(restController),
new RestDeleteTrainedModelAction(restController),
- new RestGetTrainedModelsStatsAction(restController)
+ new RestGetTrainedModelsStatsAction(restController),
+ new RestPutTrainedModelAction(restController)
);
}
@@ -844,7 +848,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new ActionHandler<>(InternalInferModelAction.INSTANCE, TransportInternalInferModelAction.class),
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
- new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class)
+ new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class),
+ new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class)
);
}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
new file mode 100644
index 00000000000..a520f621672
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.action;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.master.TransportMasterNodeAction;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.block.ClusterBlockException;
+import org.elasticsearch.cluster.block.ClusterBlockLevel;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.license.License;
+import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.XPackField;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.List;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
+public class TransportPutTrainedModelAction extends TransportMasterNodeAction {
+
+ private final TrainedModelProvider trainedModelProvider;
+ private final XPackLicenseState licenseState;
+ private final NamedXContentRegistry xContentRegistry;
+ private final Client client;
+
+ @Inject
+ public TransportPutTrainedModelAction(TransportService transportService, ClusterService clusterService,
+ ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters,
+ IndexNameExpressionResolver indexNameExpressionResolver, Client client,
+ TrainedModelProvider trainedModelProvider, NamedXContentRegistry xContentRegistry) {
+ super(PutTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, Request::new,
+ indexNameExpressionResolver);
+ this.licenseState = licenseState;
+ this.trainedModelProvider = trainedModelProvider;
+ this.xContentRegistry = xContentRegistry;
+ this.client = client;
+ }
+
+ @Override
+ protected String executor() {
+ return ThreadPool.Names.SAME;
+ }
+
+ @Override
+ protected Response read(StreamInput in) throws IOException {
+ return new Response(in);
+ }
+
+ @Override
+ protected void masterOperation(Request request, ClusterState state, ActionListener listener) {
+ try {
+ request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
+ request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
+ } catch (IOException ex) {
+ listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
+ ex,
+ request.getTrainedModelConfig().getModelId()));
+ return;
+ } catch (ElasticsearchException ex) {
+ listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.",
+ ex,
+ request.getTrainedModelConfig().getModelId()));
+ return;
+ }
+
+ TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
+ .setVersion(Version.CURRENT)
+ .setCreateTime(Instant.now())
+ .setCreatedBy("api_user")
+ .setLicenseLevel(License.OperationMode.PLATINUM.description())
+ .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
+ .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
+ .build();
+
+ ActionListener tagsModelIdCheckListener = ActionListener.wrap(
+ r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
+ storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
+ listener::onFailure
+ )),
+ listener::onFailure
+ );
+
+ ActionListener modelIdTagCheckListener = ActionListener.wrap(
+ r -> checkTagsAgainstModelIds(request.getTrainedModelConfig().getTags(), tagsModelIdCheckListener),
+ listener::onFailure
+ );
+
+ checkModelIdAgainstTags(request.getTrainedModelConfig().getModelId(), modelIdTagCheckListener);
+ }
+
+ private void checkModelIdAgainstTags(String modelId, ActionListener listener) {
+ QueryBuilder builder = QueryBuilders.constantScoreQuery(
+ QueryBuilders.boolQuery()
+ .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId)));
+ SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
+ SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
+ executeAsyncWithOrigin(client.threadPool().getThreadContext(),
+ ML_ORIGIN,
+ searchRequest,
+ ActionListener.wrap(
+ response -> {
+ if (response.getHits().getTotalHits().value > 0) {
+ listener.onFailure(
+ ExceptionsHelper.badRequestException(
+ Messages.getMessage(Messages.INFERENCE_MODEL_ID_AND_TAGS_UNIQUE, modelId)));
+ return;
+ }
+ listener.onResponse(null);
+ },
+ listener::onFailure
+ ),
+ client::search);
+ }
+
+ private void checkTagsAgainstModelIds(List tags, ActionListener listener) {
+ if (tags.isEmpty()) {
+ listener.onResponse(null);
+ return;
+ }
+
+ QueryBuilder builder = QueryBuilders.constantScoreQuery(
+ QueryBuilders.boolQuery()
+ .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags)));
+ SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
+ SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
+ executeAsyncWithOrigin(client.threadPool().getThreadContext(),
+ ML_ORIGIN,
+ searchRequest,
+ ActionListener.wrap(
+ response -> {
+ if (response.getHits().getTotalHits().value > 0) {
+ listener.onFailure(
+ ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE, tags)));
+ return;
+ }
+ listener.onResponse(null);
+ },
+ listener::onFailure
+ ),
+ client::search);
+ }
+
+ @Override
+ protected ClusterBlockException checkBlock(Request request, ClusterState state) {
+ return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
+ }
+
+ @Override
+ protected void doExecute(Task task, Request request, ActionListener listener) {
+ if (licenseState.isMachineLearningAllowed()) {
+ super.doExecute(task, request, listener);
+ } else {
+ listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
+ }
+ }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
index 2c14fd70f10..d63dbf1bc4b 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
@@ -175,10 +175,12 @@ public class TrainedModelProvider {
r -> {
assert r.getItems().length == 2;
if (r.getItems()[0].isFailed()) {
+
logger.error(new ParameterizedMessage(
"[{}] failed to store trained model config for inference",
trainedModelConfig.getModelId()),
r.getItems()[0].getFailure().getCause());
+
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
return;
}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
new file mode 100644
index 00000000000..cb3f4e0edde
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestController;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.io.IOException;
+
+public class RestPutTrainedModelAction extends BaseRestHandler {
+
+ public RestPutTrainedModelAction(RestController controller) {
+ controller.registerHandler(RestRequest.Method.PUT,
+ MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}",
+ this);
+ }
+
+ @Override
+ public String getName() {
+ return "xpack_ml_put_trained_model_action";
+ }
+
+ @Override
+ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+ String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
+ XContentParser parser = restRequest.contentParser();
+ PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser);
+ putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));
+
+ return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
+ }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java
index 2e6d6b21344..c2042811e8b 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java
@@ -13,7 +13,6 @@ import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.ClusterState;
@@ -37,26 +36,30 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
import org.elasticsearch.xpack.core.ml.client.MachineLearningClient;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
-import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
-import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
@@ -561,12 +564,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
" \"target_field\": \"regression_value\",\n" +
" \"model_id\": \"modelprocessorlicensetest\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
- " \"field_mappings\": {\n" +
- " \"col1\": \"col1\",\n" +
- " \"col2\": \"col2\",\n" +
- " \"col3\": \"col3\",\n" +
- " \"col4\": \"col4\"\n" +
- " }\n" +
+ " \"field_mappings\": {}\n" +
" }\n" +
" }]}\n";
// Creating a pipeline should work
@@ -748,76 +746,22 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}
- private void putInferenceModel(String modelId) throws Exception {
- String config = "" +
- "{\n" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
- " \"description\": \"test model for classification\",\n" +
- " \"version\": \"7.6.0\",\n" +
- " \"created_by\": \"benwtrent\",\n" +
- " \"license_level\": \"platinum\",\n" +
- " \"estimated_heap_memory_usage_bytes\": 0,\n" +
- " \"estimated_operations\": 0,\n" +
- " \"created_time\": 0\n" +
- "}";
- String definition = "" +
- "{" +
- " \"trained_model\": {\n" +
- " \"tree\": {\n" +
- " \"feature_names\": [\n" +
- " \"col1_male\",\n" +
- " \"col1_female\",\n" +
- " \"col2_encoded\",\n" +
- " \"col3_encoded\",\n" +
- " \"col4\"\n" +
- " ],\n" +
- " \"tree_structure\": [\n" +
- " {\n" +
- " \"node_index\": 0,\n" +
- " \"split_feature\": 0,\n" +
- " \"split_gain\": 12.0,\n" +
- " \"threshold\": 10.0,\n" +
- " \"decision_type\": \"lte\",\n" +
- " \"default_left\": true,\n" +
- " \"left_child\": 1,\n" +
- " \"right_child\": 2\n" +
- " },\n" +
- " {\n" +
- " \"node_index\": 1,\n" +
- " \"leaf_value\": 1\n" +
- " },\n" +
- " {\n" +
- " \"node_index\": 2,\n" +
- " \"leaf_value\": 2\n" +
- " }\n" +
- " ],\n" +
- " \"target_type\": \"regression\"\n" +
- " }\n" +
- " }" +
- "}";
- String compressedDefinitionString =
- InferenceToXContentCompressor.deflate(new BytesArray(definition.getBytes(StandardCharsets.UTF_8)));
- String compressedDefinition = "" +
- "{" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"doc_type\": \"" + TrainedModelDefinitionDoc.NAME + "\",\n" +
- " \"doc_num\": " + 0 + ",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinitionString.length() + ",\n" +
- " \"definition_length\": " + compressedDefinitionString.length() + ",\n" +
- " \"definition\": \"" + compressedDefinitionString + "\"\n" +
- "}";
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
- .setId(modelId)
- .setSource(config, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
- .setId(TrainedModelDefinitionDoc.docId(modelId, 0))
- .setSource(compressedDefinition, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
+ private void putInferenceModel(String modelId) {
+ TrainedModelConfig config = TrainedModelConfig.builder()
+ .setParsedDefinition(
+ new TrainedModelDefinition.Builder()
+ .setTrainedModel(
+ Tree.builder()
+ .setTargetType(TargetType.REGRESSION)
+ .setFeatureNames(Arrays.asList("feature1"))
+ .setNodes(TreeNode.builder(0).setLeafValue(1.0))
+ .build())
+ .setPreProcessors(Collections.emptyList()))
+ .setModelId(modelId)
+ .setDescription("test model for classification")
+ .setInput(new TrainedModelInput(Arrays.asList("feature1")))
+ .build();
+ client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
}
private static OperationMode randomInvalidLicenseType() {
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
index b4c7b447c53..b518e01f99d 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
@@ -200,7 +200,6 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
-
.setDescription("trained model config for test")
.setModelId(modelId)
.setVersion(Version.CURRENT)
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json
new file mode 100644
index 00000000000..a58fa135407
--- /dev/null
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json
@@ -0,0 +1,28 @@
+{
+ "ml.put_trained_model":{
+ "documentation":{
+ "url":"TODO"
+ },
+ "stability":"experimental",
+ "url":{
+ "paths":[
+ {
+ "path":"/_ml/inference/{model_id}",
+ "methods":[
+ "PUT"
+ ],
+ "parts":{
+ "model_id":{
+ "type":"string",
+ "description":"The ID of the trained models to store"
+ }
+ }
+ }
+ ]
+ },
+ "body": {
+ "description":"The trained model configuration",
+ "required":true
+ }
+ }
+}
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
index f72fd1120d8..f5f9a56bab8 100644
--- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -1,3 +1,74 @@
+setup:
+ - skip:
+ features: headers
+ - do:
+ headers:
+ Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+ ml.put_trained_model:
+ model_id: a-regression-model-0
+ body: >
+ {
+ "description": "empty model for tests",
+ "input": {"field_names": ["field1", "field2"]},
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "regression"
+ }
+ }
+ }
+ }
+
+ - do:
+ headers:
+ Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+ ml.put_trained_model:
+ model_id: a-regression-model-1
+ body: >
+ {
+ "description": "empty model for tests",
+ "input": {"field_names": ["field1", "field2"]},
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "regression"
+ }
+ }
+ }
+ }
+ - do:
+ headers:
+ Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+ ml.put_trained_model:
+ model_id: a-classification-model
+ body: >
+ {
+ "description": "empty model for tests",
+ "input": {"field_names": ["field1", "field2"]},
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "classification",
+ "classification_labels": ["no", "yes"]
+ }
+ }
+ }
+ }
---
"Test get given missing trained model":
@@ -24,56 +95,52 @@
- match: { count: 0 }
- match: { trained_model_configs: [] }
---
+"Test get models":
+ - do:
+ ml.get_trained_models:
+ model_id: "*"
+ - match: { count: 4 }
+ - match: { trained_model_configs.0.model_id: "a-classification-model" }
+ - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
+ - match: { trained_model_configs.2.model_id: "a-regression-model-1" }
+
+ - do:
+ ml.get_trained_models:
+ model_id: "a-regression*"
+ - match: { count: 2 }
+ - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+ - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
+
+ - do:
+ ml.get_trained_models:
+ model_id: "*"
+ from: 0
+ size: 2
+ - match: { count: 4 }
+ - match: { trained_model_configs.0.model_id: "a-classification-model" }
+ - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
+
+ - do:
+ ml.get_trained_models:
+ model_id: "*"
+ from: 1
+ size: 1
+ - match: { count: 4 }
+ - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+---
"Test delete given unused trained model":
-
- - do:
- index:
- id: trained_model_config-unused-regression-model-0
- index: .ml-inference-000001
- body: >
- {
- "model_id": "unused-regression-model",
- "created_by": "ml_tests",
- "version": "8.0.0",
- "description": "empty model for tests",
- "create_time": 0,
- "model_version": 0,
- "model_type": "local"
- }
- - do:
- indices.refresh: {}
-
- do:
ml.delete_trained_model:
- model_id: "unused-regression-model"
+ model_id: "a-classification-model"
- match: { acknowledged: true }
-
---
"Test delete with missing model":
- do:
catch: missing
ml.delete_trained_model:
model_id: "missing-trained-model"
-
---
"Test delete given used trained model":
- - do:
- index:
- id: trained_model_config-used-regression-model-0
- index: .ml-inference-000001
- body: >
- {
- "model_id": "used-regression-model",
- "created_by": "ml_tests",
- "version": "8.0.0",
- "description": "empty model for tests",
- "create_time": 0,
- "model_version": 0,
- "model_type": "local"
- }
- - do:
- indices.refresh: {}
-
- do:
ingest.put_pipeline:
id: "regression-model-pipeline"
@@ -82,7 +149,7 @@
"processors": [
{
"inference" : {
- "model_id" : "used-regression-model",
+ "model_id" : "a-regression-model-0",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
@@ -95,12 +162,12 @@
- do:
catch: conflict
ml.delete_trained_model:
- model_id: "used-regression-model"
+ model_id: "a-regression-model-0"
---
"Test get pre-packaged trained models":
- do:
ml.get_trained_models:
- model_id: "_all"
+ model_id: "lang_ident_model_1"
allow_no_match: false
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }