diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 4967d8091c9..2e077f547e3 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -73,6 +73,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; +import org.elasticsearch.client.ml.PutTrainedModelRequest; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; @@ -792,6 +793,16 @@ final class MLRequestConverters { return new Request(HttpDelete.METHOD_NAME, endpoint); } + static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) throws IOException { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "inference") + .addPathPart(putTrainedModelRequest.getTrainedModelConfig().getModelId()) + .build(); + Request request = new Request(HttpPut.METHOD_NAME, endpoint); + request.setEntity(createEntity(putTrainedModelRequest, REQUEST_BODY_CONTENT_TYPE)); + return request; + } + static Request putFilter(PutFilterRequest putFilterRequest) throws IOException { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 0a71b8ddb01..bdb2f22f3b3 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -100,6 +100,8 @@ import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutFilterResponse; import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.PutJobResponse; +import org.elasticsearch.client.ml.PutTrainedModelRequest; +import org.elasticsearch.client.ml.PutTrainedModelResponse; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; @@ -2340,6 +2342,48 @@ public final class MachineLearningClient { Collections.emptySet()); } + /** + * Put trained model config + *

+ * For additional info + * see + * PUT Trained Model Config documentation + * + * @param request The {@link PutTrainedModelRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link PutTrainedModelResponse} response object + */ + public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::putTrainedModel, + options, + PutTrainedModelResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Put trained model config asynchronously and notifies listener upon completion + *

+ * For additional info + * see + * PUT Trained Model Config documentation + * + * @param request The {@link PutTrainedModelRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + * @return cancellable that may be used to cancel the request + */ + public Cancellable putTrainedModelAsync(PutTrainedModelRequest request, + RequestOptions options, + ActionListener listener) { + return restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::putTrainedModel, + options, + PutTrainedModelResponse::fromXContent, + listener, + Collections.emptySet()); + } + /** * Gets trained model stats *

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java new file mode 100644 index 00000000000..780ec31771b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java @@ -0,0 +1,66 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + + +public class PutTrainedModelRequest implements Validatable, ToXContentObject { + + private final TrainedModelConfig config; + + public PutTrainedModelRequest(TrainedModelConfig config) { + this.config = config; + } + + public TrainedModelConfig getTrainedModelConfig() { + return config; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return config.toXContent(builder, params); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PutTrainedModelRequest request = (PutTrainedModelRequest) o; + return Objects.equals(config, request.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } + + @Override + public final String toString() { + return Strings.toString(config); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java new file mode 100644 index 00000000000..3bc81f18129 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + + +public class PutTrainedModelResponse implements ToXContentObject { + + private final TrainedModelConfig trainedModelConfig; + + public static PutTrainedModelResponse fromXContent(XContentParser parser) throws IOException { + return new PutTrainedModelResponse(TrainedModelConfig.PARSER.parse(parser, null).build()); + } + + public PutTrainedModelResponse(TrainedModelConfig trainedModelConfig) { + this.trainedModelConfig = trainedModelConfig; + } + + public TrainedModelConfig getResponse() { + return trainedModelConfig; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return trainedModelConfig.toXContent(builder, params); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PutTrainedModelResponse response = (PutTrainedModelResponse) o; + return Objects.equals(trainedModelConfig, response.trainedModelConfig); + } + + @Override + public int hashCode() { + return Objects.hash(trainedModelConfig); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java new file mode 100644 index 00000000000..9bec4c4eb5d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java @@ -0,0 +1,81 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +/** + * Collection of helper methods. Similar to CompressedXContent, but this utilizes GZIP. + */ +public final class InferenceToXContentCompressor { + private static final int BUFFER_SIZE = 4096; + private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum + + private InferenceToXContentCompressor() {} + + public static String deflate(T objectToCompress) throws IOException { + BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false); + return deflate(reference); + } + + public static T inflate(String compressedString, + CheckedFunction parserFunction, + NamedXContentRegistry xContentRegistry) throws IOException { + try(XContentParser parser = XContentHelper.createParser(xContentRegistry, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + inflate(compressedString, MAX_INFLATED_BYTES), + XContentType.JSON)) { + return parserFunction.apply(parser); + } + } + + static BytesReference inflate(String compressedString, long streamSize) throws IOException { + byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); + InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE); + InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize); + return Streams.readFully(inflateStream); + } + + private static String deflate(BytesReference reference) throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) { + reference.writeTo(compressedOutput); + } + return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java new file mode 100644 index 00000000000..683e23dc9d7 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java @@ -0,0 +1,68 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + + +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; + +/** + * This is a pared down bounded input stream. + * Only read is specifically enforced. + */ +final class SimpleBoundedInputStream extends InputStream { + + private final InputStream in; + private final long maxBytes; + private long numBytes; + + SimpleBoundedInputStream(InputStream inputStream, long maxBytes) { + this.in = Objects.requireNonNull(inputStream, "inputStream"); + if (maxBytes < 0) { + throw new IllegalArgumentException("[maxBytes] must be greater than or equal to 0"); + } + this.maxBytes = maxBytes; + } + + + /** + * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read. + * @return The byte read. -1 on internal stream completion or when maxBytes is exceeded. + * @throws IOException on failure + */ + @Override + public int read() throws IOException { + // We have reached the maximum, signal stream completion. + if (numBytes >= maxBytes) { + return -1; + } + numBytes++; + return in.read(); + } + + /** + * Delegates `close` to the wrapped InputStream + * @throws IOException on failure + */ + @Override + public void close() throws IOException { + in.close(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 23eb01fb3b1..9d2b323cf48 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.time.Instant; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -111,7 +112,7 @@ public class TrainedModelConfig implements ToXContentObject { this.modelId = modelId; this.createdBy = createdBy; this.version = version; - this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli()); + this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli()); this.definition = definition; this.compressedDefinition = compressedDefinition; this.description = description; @@ -293,12 +294,12 @@ public class TrainedModelConfig implements ToXContentObject { return this; } - public Builder setCreatedBy(String createdBy) { + private Builder setCreatedBy(String createdBy) { this.createdBy = createdBy; return this; } - public Builder setVersion(Version version) { + private Builder setVersion(Version version) { this.version = version; return this; } @@ -312,7 +313,7 @@ public class TrainedModelConfig implements ToXContentObject { return this; } - public Builder setCreateTime(Instant createTime) { + private Builder setCreateTime(Instant createTime) { this.createTime = createTime; return this; } @@ -322,6 +323,10 @@ public class TrainedModelConfig implements ToXContentObject { return this; } + public Builder setTags(String... tags) { + return setTags(Arrays.asList(tags)); + } + public Builder setMetadata(Map metadata) { this.metadata = metadata; return this; @@ -347,17 +352,17 @@ public class TrainedModelConfig implements ToXContentObject { return this; } - public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { + private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { this.estimatedHeapMemory = estimatedHeapMemory; return this; } - public Builder setEstimatedOperations(Long estimatedOperations) { + private Builder setEstimatedOperations(Long estimatedOperations) { this.estimatedOperations = estimatedOperations; return this; } - public Builder setLicenseLevel(String licenseLevel) { + private Builder setLicenseLevel(String licenseLevel) { this.licenseLevel = licenseLevel; return this; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java index 10f849cac48..9b19323023d 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; +import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -48,6 +49,10 @@ public class TrainedModelInput implements ToXContentObject { this.fieldNames = fieldNames; } + public TrainedModelInput(String... fieldNames) { + this(Arrays.asList(fieldNames)); + } + public static TrainedModelInput fromXContent(XContentParser parser) throws IOException { return PARSER.parse(parser, null); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 24825dfc265..f5733ef3a0d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -71,6 +71,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; +import org.elasticsearch.client.ml.PutTrainedModelRequest; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; @@ -91,6 +92,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.client.ml.inference.TrainedModelConfigTests; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.Detector; import org.elasticsearch.client.ml.job.config.Job; @@ -874,6 +878,20 @@ public class MLRequestConvertersTests extends ESTestCase { assertNull(request.getEntity()); } + public void testPutTrainedModel() throws IOException { + TrainedModelConfig trainedModelConfig = TrainedModelConfigTests.createTestTrainedModelConfig(); + PutTrainedModelRequest putTrainedModelRequest = new PutTrainedModelRequest(trainedModelConfig); + + Request request = MLRequestConverters.putTrainedModel(putTrainedModelRequest); + + assertEquals(HttpPut.METHOD_NAME, request.getMethod()); + assertThat(request.getEndpoint(), equalTo("/_ml/inference/" + trainedModelConfig.getModelId())); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) { + TrainedModelConfig parsedTrainedModelConfig = TrainedModelConfig.PARSER.apply(parser, null).build(); + assertThat(parsedTrainedModelConfig, equalTo(trainedModelConfig)); + } + } + public void testPutFilter() throws IOException { MlFilter filter = MlFilterTests.createRandomBuilder("foo").build(); PutFilterRequest putFilterRequest = new PutFilterRequest(filter); @@ -1046,6 +1064,7 @@ public class MLRequestConvertersTests extends ESTestCase { namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 170fd4e858a..247b726e008 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -101,6 +101,8 @@ import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutFilterResponse; import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.PutJobResponse; +import org.elasticsearch.client.ml.PutTrainedModelRequest; +import org.elasticsearch.client.ml.PutTrainedModelResponse; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; @@ -146,9 +148,12 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal import org.elasticsearch.client.ml.dataframe.explain.FieldSelection; import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; +import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor; +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.client.ml.inference.TrainedModelConfig; import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.client.ml.inference.TrainedModelInput; import org.elasticsearch.client.ml.inference.TrainedModelStats; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; @@ -162,14 +167,12 @@ import org.elasticsearch.client.ml.job.config.MlFilter; import org.elasticsearch.client.ml.job.process.ModelSnapshot; import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -178,11 +181,9 @@ import org.elasticsearch.search.SearchHit; import org.junit.After; import java.io.IOException; -import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; -import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -190,7 +191,6 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; -import java.util.zip.GZIPOutputStream; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; @@ -2222,6 +2222,50 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { } } + public void testPutTrainedModel() throws Exception { + String modelId = "test-put-trained-model"; + String modelIdCompressed = "test-put-trained-model-compressed-definition"; + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build(); + TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() + .setDefinition(definition) + .setModelId(modelId) + .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) + .setDescription("test model") + .build(); + PutTrainedModelResponse putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig), + machineLearningClient::putTrainedModel, + machineLearningClient::putTrainedModelAsync); + TrainedModelConfig createdModel = putTrainedModelResponse.getResponse(); + assertThat(createdModel.getModelId(), equalTo(modelId)); + + definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build(); + trainedModelConfig = TrainedModelConfig.builder() + .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) + .setModelId(modelIdCompressed) + .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) + .setDescription("test model") + .build(); + putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig), + machineLearningClient::putTrainedModel, + machineLearningClient::putTrainedModelAsync); + createdModel = putTrainedModelResponse.getResponse(); + assertThat(createdModel.getModelId(), equalTo(modelIdCompressed)); + + GetTrainedModelsResponse getTrainedModelsResponse = execute( + new GetTrainedModelsRequest(modelIdCompressed).setDecompressDefinition(true).setIncludeDefinition(true), + machineLearningClient::getTrainedModels, + machineLearningClient::getTrainedModelsAsync); + + assertThat(getTrainedModelsResponse.getCount(), equalTo(1L)); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1)); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue())); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue()))); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdCompressed)); + } + public void testGetTrainedModelsStats() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String modelIdPrefix = "a-get-trained-model-stats-"; @@ -2504,56 +2548,13 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { private void putTrainedModel(String modelId) throws IOException { TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build(); - highLevelClient().index( - new IndexRequest(".ml-inference-000001") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(modelConfigString(modelId), XContentType.JSON) - .id(modelId), - RequestOptions.DEFAULT); - - highLevelClient().index( - new IndexRequest(".ml-inference-000001") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON) - .id("trained_model_definition_doc-" + modelId + "-0"), - RequestOptions.DEFAULT); - } - - private String compressDefinition(TrainedModelDefinition definition) throws IOException { - BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false); - BytesStreamOutput out = new BytesStreamOutput(); - try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) { - reference.writeTo(compressedOutput); - } - return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); - } - - private static String modelConfigString(String modelId) { - return "{\n" + - " \"doc_type\": \"trained_model_config\",\n" + - " \"model_id\": \"" + modelId + "\",\n" + - " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + - " \"description\": \"test model\",\n" + - " \"version\": \"7.6.0\",\n" + - " \"license_level\": \"platinum\",\n" + - " \"created_by\": \"ml_test\",\n" + - " \"estimated_heap_memory_usage_bytes\": 0," + - " \"estimated_operations\": 0," + - " \"created_time\": 0\n" + - "}"; - } - - private static String modelDocString(String compressedDefinition, String modelId) { - return "" + - "{" + - "\"model_id\": \"" + modelId + "\",\n" + - "\"doc_num\": 0,\n" + - "\"doc_type\": \"trained_model_definition_doc\",\n" + - " \"compression_version\": " + 1 + ",\n" + - " \"total_definition_length\": " + compressedDefinition.length() + ",\n" + - " \"definition_length\": " + compressedDefinition.length() + ",\n" + - "\"definition\": \"" + compressedDefinition + "\"\n" + - "}"; + TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() + .setDefinition(definition) + .setModelId(modelId) + .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) + .setDescription("test model") + .build(); + highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT); } private void waitForJobToClose(String jobId) throws Exception { @@ -2798,4 +2799,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { mlInfoResponse = machineLearningClient.getMlInfo(new MlInfoRequest(), RequestOptions.DEFAULT); assertThat(mlInfoResponse.getInfo().get("upgrade_mode"), equalTo(false)); } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 94863cbd5e6..860fe533fd3 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -114,6 +114,8 @@ import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutFilterResponse; import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.PutJobResponse; +import org.elasticsearch.client.ml.PutTrainedModelRequest; +import org.elasticsearch.client.ml.PutTrainedModelResponse; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; @@ -162,10 +164,14 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal import org.elasticsearch.client.ml.dataframe.explain.FieldSelection; import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; +import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor; +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.client.ml.inference.TrainedModelConfig; import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.client.ml.inference.TrainedModelInput; import org.elasticsearch.client.ml.inference.TrainedModelStats; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -186,12 +192,11 @@ import org.elasticsearch.client.ml.job.results.Influencer; import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -202,12 +207,10 @@ import org.elasticsearch.tasks.TaskId; import org.junit.After; import java.io.IOException; -import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; -import java.util.Base64; import java.util.Collections; import java.util.Date; import java.util.HashMap; @@ -216,7 +219,6 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; -import java.util.zip.GZIPOutputStream; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.closeTo; @@ -3625,6 +3627,79 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { } } + public void testPutTrainedModel() throws Exception { + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build(); + // tag::put-trained-model-config + TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() + .setDefinition(definition) // <1> + .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2> + .setModelId("my-new-trained-model") // <3> + .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4> + .setDescription("test model") // <5> + .setMetadata(new HashMap<>()) // <6> + .setTags("my_regression_models") // <7> + .build(); + // end::put-trained-model-config + + trainedModelConfig = TrainedModelConfig.builder() + .setDefinition(definition) + .setModelId("my-new-trained-model") + .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) + .setDescription("test model") + .setMetadata(new HashMap<>()) + .setTags("my_regression_models") + .build(); + + RestHighLevelClient client = highLevelClient(); + { + // tag::put-trained-model-request + PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); // <1> + // end::put-trained-model-request + + // tag::put-trained-model-execute + PutTrainedModelResponse response = client.machineLearning().putTrainedModel(request, RequestOptions.DEFAULT); + // end::put-trained-model-execute + + // tag::put-trained-model-response + TrainedModelConfig model = response.getResponse(); + // end::put-trained-model-response + + assertThat(model.getModelId(), equalTo(trainedModelConfig.getModelId())); + highLevelClient().machineLearning() + .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT); + } + { + PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); + + // tag::put-trained-model-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(PutTrainedModelResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::put-trained-model-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::put-trained-model-execute-async + client.machineLearning().putTrainedModelAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::put-trained-model-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + + highLevelClient().machineLearning() + .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT); + } + } + public void testGetTrainedModelsStats() throws Exception { putTrainedModel("my-trained-model"); RestHighLevelClient client = highLevelClient(); @@ -4088,57 +4163,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { } private void putTrainedModel(String modelId) throws IOException { - TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); - highLevelClient().index( - new IndexRequest(".ml-inference-000001") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(modelConfigString(modelId), XContentType.JSON) - .id(modelId), - RequestOptions.DEFAULT); - - highLevelClient().index( - new IndexRequest(".ml-inference-000001") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON) - .id("trained_model_definition_doc-" + modelId + "-0"), - RequestOptions.DEFAULT); + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build(); + TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() + .setDefinition(definition) + .setModelId(modelId) + .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) + .setDescription("test model") + .build(); + highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT); } - private String compressDefinition(TrainedModelDefinition definition) throws IOException { - BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false); - BytesStreamOutput out = new BytesStreamOutput(); - try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) { - reference.writeTo(compressedOutput); - } - return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); - } - - private static String modelConfigString(String modelId) { - return "{\n" + - " \"doc_type\": \"trained_model_config\",\n" + - " \"model_id\": \"" + modelId + "\",\n" + - " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + - " \"description\": \"test model for\",\n" + - " \"version\": \"7.6.0\",\n" + - " \"license_level\": \"platinum\",\n" + - " \"created_by\": \"ml_test\",\n" + - " \"estimated_heap_memory_usage_bytes\": 0," + - " \"estimated_operations\": 0," + - " \"created_time\": 0\n" + - "}"; - } - - private static String modelDocString(String compressedDefinition, String modelId) { - return "" + - "{" + - "\"model_id\": \"" + modelId + "\",\n" + - "\"doc_num\": 0,\n" + - "\"doc_type\": \"trained_model_definition_doc\",\n" + - " \"compression_version\": " + 1 + ",\n" + - " \"total_definition_length\": " + compressedDefinition.length() + ",\n" + - " \"definition_length\": " + compressedDefinition.length() + ",\n" + - "\"definition\": \"" + compressedDefinition + "\"\n" + - "}"; + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); } private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG = diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java new file mode 100644 index 00000000000..b3956c5c6af --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java @@ -0,0 +1,52 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.client.ml.inference.TrainedModelConfigTests; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class PutTrainedModelActionRequestTests extends AbstractXContentTestCase { + + @Override + protected PutTrainedModelRequest createTestInstance() { + return new PutTrainedModelRequest(TrainedModelConfigTests.createTestTrainedModelConfig()); + } + + @Override + protected PutTrainedModelRequest doParseInstance(XContentParser parser) throws IOException { + return new PutTrainedModelRequest(TrainedModelConfig.PARSER.apply(parser, null).build()); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java new file mode 100644 index 00000000000..61e1638547b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java @@ -0,0 +1,52 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.client.ml.inference.TrainedModelConfigTests; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class PutTrainedModelActionResponseTests extends AbstractXContentTestCase { + + @Override + protected PutTrainedModelResponse createTestInstance() { + return new PutTrainedModelResponse(TrainedModelConfigTests.createTestTrainedModelConfig()); + } + + @Override + protected PutTrainedModelResponse doParseInstance(XContentParser parser) throws IOException { + return new PutTrainedModelResponse(TrainedModelConfig.PARSER.apply(parser, null).build()); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java new file mode 100644 index 00000000000..11747638a2c --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java @@ -0,0 +1,70 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class InferenceToXContentCompressorTests extends ESTestCase { + + public void testInflateAndDeflate() throws IOException { + for(int i = 0; i < 10; i++) { + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); + String firstDeflate = InferenceToXContentCompressor.deflate(definition); + TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate, + parser -> TrainedModelDefinition.fromXContent(parser).build(), + xContentRegistry()); + + // Did we inflate to the same object? + assertThat(inflatedDefinition, equalTo(definition)); + } + } + + public void testInflateTooLargeStream() throws IOException { + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); + String firstDeflate = InferenceToXContentCompressor.deflate(definition); + BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L); + assertThat(inflatedBytes.length(), equalTo(10)); + try(XContentParser parser = XContentHelper.createParser(xContentRegistry(), + LoggingDeprecationHandler.INSTANCE, + inflatedBytes, + XContentType.JSON)) { + expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser)); + } + } + + public void testInflateGarbage() { + expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index 51f9692b8b6..81c64b3a22c 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -37,6 +37,24 @@ import java.util.stream.Stream; public class TrainedModelConfigTests extends AbstractXContentTestCase { + public static TrainedModelConfig createTestTrainedModelConfig() { + return new TrainedModelConfig( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + Version.CURRENT, + randomBoolean() ? null : randomAlphaOfLength(100), + Instant.ofEpochMilli(randomNonNegativeLong()), + randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), + randomBoolean() ? null : randomAlphaOfLength(100), + randomBoolean() ? null : + Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), + randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomBoolean() ? null : TrainedModelInputTests.createRandomInput(), + randomBoolean() ? null : randomNonNegativeLong(), + randomBoolean() ? null : randomNonNegativeLong(), + randomBoolean() ? null : randomFrom("platinum", "basic")); + } + @Override protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException { return TrainedModelConfig.fromXContent(parser); @@ -54,22 +72,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), - randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomBoolean() ? null : TrainedModelInputTests.createRandomInput(), - randomBoolean() ? null : randomNonNegativeLong(), - randomBoolean() ? null : randomNonNegativeLong(), - randomBoolean() ? null : randomFrom("platinum", "basic")); - + return createTestTrainedModelConfig(); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 72a640d609b..5c857bb625b 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -67,15 +67,17 @@ public class EnsembleTests extends AbstractXContentTestCase { .collect(Collectors.toList()); int numberOfModels = randomIntBetween(1, 10); List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType)) - .limit(numberOfFeatures) + .limit(numberOfModels) .collect(Collectors.toList()); - OutputAggregator outputAggregator = null; - if (randomBoolean()) { - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); - outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights)); + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); + List possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights), + new LogisticRegression(weights))); + if (targetType.equals(TargetType.REGRESSION)) { + possibleAggregators.add(new WeightedSum(weights)); } + OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0])); List categoryLabels = null; - if (randomBoolean()) { + if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); } return new Ensemble(featureNames, diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java index 748dc982e67..5e105dffd8b 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java @@ -84,7 +84,7 @@ public class TreeTests extends AbstractXContentTestCase { childNodes = nextNodes; } List categoryLabels = null; - if (randomBoolean()) { + if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); } return builder.setClassificationLabels(categoryLabels) diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc new file mode 100644 index 00000000000..dadc8dcf65a --- /dev/null +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -0,0 +1,53 @@ +-- +:api: put-trained-model +:request: PutTrainedModelRequest +:response: PutTrainedModelResponse +-- +[role="xpack"] +[id="{upid}-{api}"] +=== Put Trained Model API + +Creates a new trained model for inference. +The API accepts a +{request}+ object as a request and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Put Trained Model request + +A +{request}+ requires the following argument: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> The configuration of the {infer} Trained Model to create + +[id="{upid}-{api}-config"] +==== Trained Model configuration + +The `TrainedModelConfig` object contains all the details about the trained model +configuration and contains the following arguments: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-config] +-------------------------------------------------- +<1> The {infer} definition for the model +<2> Optionally, if the {infer} definition is large, you may choose to compress it for transport. + Do not supply both the compressed and uncompressed definitions. +<3> The unique model id +<4> The input field names for the model definition +<5> Optionally, a human-readable description +<6> Optionally, an object map contain metadata about the model +<7> Optionally, an array of tags to organize the model + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the newly created trained model. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc index e0d228b5d1e..4b848819702 100644 --- a/docs/java-rest/high-level/supported-apis.asciidoc +++ b/docs/java-rest/high-level/supported-apis.asciidoc @@ -304,6 +304,7 @@ The Java High Level REST Client supports the following Machine Learning APIs: * <<{upid}-evaluate-data-frame>> * <<{upid}-explain-data-frame-analytics>> * <<{upid}-get-trained-models>> +* <<{upid}-put-trained-model>> * <<{upid}-get-trained-models-stats>> * <<{upid}-delete-trained-model>> * <<{upid}-put-filter>> @@ -359,6 +360,7 @@ include::ml/stop-data-frame-analytics.asciidoc[] include::ml/evaluate-data-frame.asciidoc[] include::ml/explain-data-frame-analytics.asciidoc[] include::ml/get-trained-models.asciidoc[] +include::ml/put-trained-model.asciidoc[] include::ml/get-trained-models-stats.asciidoc[] include::ml/delete-trained-model.asciidoc[] include::ml/put-filter.asciidoc[] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index ca0227e84ae..843e3e611df 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -126,6 +126,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; @@ -381,6 +382,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl GetTrainedModelsAction.INSTANCE, DeleteTrainedModelAction.INSTANCE, GetTrainedModelsStatsAction.INSTANCE, + PutTrainedModelAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java new file mode 100644 index 00000000000..06fbb6401a0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; + +import java.io.IOException; +import java.util.Objects; + + +public class PutTrainedModelAction extends ActionType { + + public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction(); + public static final String NAME = "cluster:monitor/xpack/ml/inference/put"; + private PutTrainedModelAction() { + super(NAME, Response::new); + } + + public static class Request extends AcknowledgedRequest { + + public static Request parseRequest(String modelId, XContentParser parser) { + TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null); + + if (builder.getModelId() == null) { + builder.setModelId(modelId).build(); + } else if (!Strings.isNullOrEmpty(modelId) && !modelId.equals(builder.getModelId())) { + // If we have model_id in both URI and body, they must be identical + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, + TrainedModelConfig.MODEL_ID.getPreferredName(), + builder.getModelId(), + modelId)); + } + // Validations are done against the builder so we can build the full config object. + // This allows us to not worry about serializing a builder class between nodes. + return new Request(builder.validate(true).build()); + } + + private final TrainedModelConfig config; + + public Request(TrainedModelConfig config) { + this.config = config; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.config = new TrainedModelConfig(in); + } + + public TrainedModelConfig getTrainedModelConfig() { + return config; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + config.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(config, request.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } + + @Override + public final String toString() { + return Strings.toString(config); + } + } + + public static class Response extends ActionResponse implements ToXContentObject { + + private final TrainedModelConfig trainedModelConfig; + + public Response(TrainedModelConfig trainedModelConfig) { + this.trainedModelConfig = trainedModelConfig; + } + + public Response(StreamInput in) throws IOException { + super(in); + trainedModelConfig = new TrainedModelConfig(in); + } + + public TrainedModelConfig getResponse() { + return trainedModelConfig; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + trainedModelConfig.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return trainedModelConfig.toXContent(builder, params); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response response = (Response) o; + return Objects.equals(trainedModelConfig, response.trainedModelConfig); + } + + @Override + public int hashCode() { + return Objects.hash(trainedModelConfig); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index be4d40efc85..95589ac8b61 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; @@ -34,6 +35,9 @@ import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; + +import static org.elasticsearch.action.ValidateActions.addValidationError; public class TrainedModelConfig implements ToXContentObject, Writeable { @@ -352,13 +356,31 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private Long estimatedHeapMemory; private Long estimatedOperations; private LazyModelDefinition definition; - private String licenseLevel = License.OperationMode.PLATINUM.description(); + private String licenseLevel; + + public Builder() {} + + public Builder(TrainedModelConfig config) { + this.modelId = config.getModelId(); + this.createdBy = config.getCreatedBy(); + this.version = config.getVersion(); + this.createTime = config.getCreateTime(); + this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition); + this.description = config.getDescription(); + this.tags = config.getTags(); + this.metadata = config.getMetadata(); + this.input = config.getInput(); + } public Builder setModelId(String modelId) { this.modelId = modelId; return this; } + public String getModelId() { + return this.modelId; + } + public Builder setCreatedBy(String createdBy) { this.createdBy = createdBy; return this; @@ -466,51 +488,96 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return this; } - // TODO move to REST level instead of here in the builder - public void validate() { - // We require a definition to be available here even though it will be stored in a different doc - ExceptionsHelper.requireNonNull(definition, DEFINITION); - ExceptionsHelper.requireNonNull(modelId, MODEL_ID); - - if (MlStrings.isValidId(modelId) == false) { - throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId)); - } - - if (MlStrings.hasValidLengthForId(modelId) == false) { - throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG, - MODEL_ID.getPreferredName(), - modelId, - MlStrings.ID_LENGTH_LIMIT)); - } - - checkIllegalSetting(version, VERSION.getPreferredName()); - checkIllegalSetting(createdBy, CREATED_BY.getPreferredName()); - checkIllegalSetting(createTime, CREATE_TIME.getPreferredName()); - checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName()); - checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName()); - checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName()); + public Builder validate() { + return validate(false); } - private static void checkIllegalSetting(Object value, String setting) { - if (value != null) { - throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting); + /** + * Runs validations against the builder. + * @return The current builder object if validations are successful + * @throws ActionRequestValidationException when there are validation failures. + */ + public Builder validate(boolean forCreation) { + // We require a definition to be available here even though it will be stored in a different doc + ActionRequestValidationException validationException = null; + if (definition == null) { + validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] must not be null.", validationException); } + if (modelId == null) { + validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException); + } + + if (modelId != null && MlStrings.isValidId(modelId) == false) { + validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID, + TrainedModelConfig.MODEL_ID.getPreferredName(), + modelId), + validationException); + } + if (modelId != null && MlStrings.hasValidLengthForId(modelId) == false) { + validationException = addValidationError(Messages.getMessage(Messages.ID_TOO_LONG, + TrainedModelConfig.MODEL_ID.getPreferredName(), + modelId, + MlStrings.ID_LENGTH_LIMIT), validationException); + } + List badTags = tags.stream() + .filter(tag -> (MlStrings.isValidId(tag) && MlStrings.hasValidLengthForId(tag)) == false) + .collect(Collectors.toList()); + if (badTags.isEmpty() == false) { + validationException = addValidationError(Messages.getMessage(Messages.INFERENCE_INVALID_TAGS, + badTags, + MlStrings.ID_LENGTH_LIMIT), + validationException); + } + + for(String tag : tags) { + if (tag.equals(modelId)) { + validationException = addValidationError("none of the tags must equal the model_id", validationException); + break; + } + } + if (forCreation) { + validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException); + validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException); + validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException); + validationException = checkIllegalSetting(estimatedHeapMemory, + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), + validationException); + validationException = checkIllegalSetting(estimatedOperations, + ESTIMATED_OPERATIONS.getPreferredName(), + validationException); + validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException); + } + + if (validationException != null) { + throw validationException; + } + + return this; + } + + private static ActionRequestValidationException checkIllegalSetting(Object value, + String setting, + ActionRequestValidationException validationException) { + if (value != null) { + return addValidationError("illegal to set [" + setting + "] at inference model creation", validationException); + } + return validationException; } public TrainedModelConfig build() { return new TrainedModelConfig( modelId, - createdBy, - version, + createdBy == null ? "user" : createdBy, + version == null ? Version.CURRENT : version, description, createTime == null ? Instant.now() : createTime, definition, tags, metadata, input, - estimatedHeapMemory, - estimatedOperations, - licenseLevel); + estimatedHeapMemory == null ? 0 : estimatedHeapMemory, + estimatedOperations == null ? 0 : estimatedOperations, + licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel); } } @@ -531,6 +598,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return new LazyModelDefinition(input.readString(), null); } + private LazyModelDefinition(LazyModelDefinition definition) { + if (definition != null) { + this.compressedString = definition.compressedString; + this.parsedDefinition = definition.parsedDefinition; + } + } + private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) { if (compressedString == null && trainedModelDefinition == null) { throw new IllegalArgumentException("unexpected null model definition"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index e176d9a2885..cf7a8b7d224 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -179,6 +179,12 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco this(true); } + public Builder(TrainedModelDefinition definition) { + this(true); + this.preProcessors = new ArrayList<>(definition.getPreProcessors()); + this.trainedModel = definition.trainedModel; + } + public Builder setPreProcessors(List preProcessors) { this.preProcessors = preProcessors; return this; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 1b80d196359..ef0fcd4fdb1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -95,6 +95,10 @@ public final class Messages { public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED = "Getting model definition is not supported when getting more than one model"; public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing"; + public static final String INFERENCE_INVALID_TAGS = "Invalid tags {0}; must only can contain lowercase alphanumeric (a-z and 0-9), " + + "hyphens or underscores, must start and end with alphanumeric, and must be less than {1} characters."; + public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids."; + public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags."; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java new file mode 100644 index 00000000000..0c39469c902 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; + +public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + String modelId = randomAlphaOfLength(10); + return new Request(TrainedModelConfigTests.createTestInstance(modelId) + .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .build()); + } + + @Override + protected Writeable.Reader instanceReader() { + return (in) -> { + Request request = new Request(in); + request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry()); + return request; + }; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java new file mode 100644 index 00000000000..5813b13c8ad --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; + +public class PutTrainedModelActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + String modelId = randomAlphaOfLength(10); + return new Response(TrainedModelConfigTests.createTestInstance(modelId) + .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .build()); + } + + @Override + protected Writeable.Reader instanceReader() { + return (in) -> { + Response response = new Response(in); + response.getResponse().ensureParsedDefinition(xContentRegistry()); + return response; + }; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 61a3b960f56..81011262012 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -5,8 +5,8 @@ */ package org.elasticsearch.xpack.core.ml.inference; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; @@ -56,14 +56,16 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase TrainedModelConfig.builder().validate()); - assertThat(ex.getMessage(), equalTo("[definition] must not be null.")); + ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class, + () -> TrainedModelConfig.builder().validate()); + assertThat(ex.getMessage(), containsString("[definition] must not be null.")); } public void testValidateWithInvalidID() { String modelId = "InvalidID-"; - ElasticsearchException ex = expectThrows(ElasticsearchException.class, + ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class, () -> TrainedModelConfig.builder() .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setModelId(modelId).validate()); - assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId))); + assertThat(ex.getMessage(), containsString(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId))); } public void testValidateWithLongID() { String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining()); - ElasticsearchException ex = expectThrows(ElasticsearchException.class, + ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class, () -> TrainedModelConfig.builder() .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setModelId(modelId).validate()); - assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT))); + assertThat(ex.getMessage(), + containsString(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT))); } public void testValidateWithIllegallyUserProvidedFields() { String modelId = "simplemodel"; - ElasticsearchException ex = expectThrows(ElasticsearchException.class, + ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class, () -> TrainedModelConfig.builder() .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setCreateTime(Instant.now()) - .setModelId(modelId).validate()); - assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation")); + .setModelId(modelId).validate(true)); + assertThat(ex.getMessage(), containsString("illegal to set [create_time] at inference model creation")); - ex = expectThrows(ElasticsearchException.class, + ex = expectThrows(ActionRequestValidationException.class, () -> TrainedModelConfig.builder() .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setVersion(Version.CURRENT) - .setModelId(modelId).validate()); - assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation")); + .setModelId(modelId).validate(true)); + assertThat(ex.getMessage(), containsString("illegal to set [version] at inference model creation")); - ex = expectThrows(ElasticsearchException.class, + ex = expectThrows(ActionRequestValidationException.class, () -> TrainedModelConfig.builder() .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setCreatedBy("ml_user") - .setModelId(modelId).validate()); - assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation")); + .setModelId(modelId).validate(true)); + assertThat(ex.getMessage(), containsString("illegal to set [created_by] at inference model creation")); } public void testSerializationWithLazyDefinition() throws IOException { diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 8c9ba6df7f0..7e19a4d606d 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -133,7 +133,6 @@ integTest.runner { 'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id', 'ml/get_datafeeds/Test get datafeed given missing datafeed_id', 'ml/inference_crud/Test delete given used trained model', - 'ml/inference_crud/Test delete given unused trained model', 'ml/inference_crud/Test delete with missing model', 'ml/inference_crud/Test get given missing trained model', 'ml/inference_crud/Test get given expression without matches and allow_no_match is false', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 20aea0f6316..545ca247a4a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -9,16 +9,20 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; import org.elasticsearch.action.ingest.SimulatePipelineResponse; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; -import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; -import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.junit.After; import org.junit.Before; import java.io.IOException; @@ -35,26 +39,14 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { @Before public void createBothModels() throws Exception { - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) - .setId("test_classification") - .setSource(CLASSIFICATION_CONFIG, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get().status(), equalTo(RestStatus.CREATED)); - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) - .setId(TrainedModelDefinitionDoc.docId("test_classification", 0)) - .setSource(buildClassificationModelDoc(), XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get().status(), equalTo(RestStatus.CREATED)); - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) - .setId("test_regression") - .setSource(REGRESSION_CONFIG, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get().status(), equalTo(RestStatus.CREATED)); - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) - .setId(TrainedModelDefinitionDoc.docId("test_regression", 0)) - .setSource(buildRegressionModelDoc(), XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get().status(), equalTo(RestStatus.CREATED)); + client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet(); + client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet(); + } + + @After + public void deleteBothModels() { + client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet(); + client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet(); } public void testPipelineCreationAndDeletion() throws Exception { @@ -392,6 +384,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for regression\",\n" + " \"version\": \"7.6.0\",\n" + + " \"definition\": " + REGRESSION_DEFINITION + ","+ " \"license_level\": \"platinum\",\n" + " \"created_by\": \"ml_test\",\n" + " \"estimated_heap_memory_usage_bytes\": 0," + @@ -519,28 +512,27 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { " }\n" + "}"; - private static String buildClassificationModelDoc() throws IOException { - String compressed = - InferenceToXContentCompressor.deflate(new BytesArray(CLASSIFICATION_DEFINITION.getBytes(StandardCharsets.UTF_8))); - return modelDocString(compressed, "test_classification"); + private TrainedModelConfig buildClassificationModel() throws IOException { + try (XContentParser parser = XContentHelper.createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(CLASSIFICATION_CONFIG), + XContentType.JSON)) { + return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); + } } - private static String buildRegressionModelDoc() throws IOException { - String compressed = InferenceToXContentCompressor.deflate(new BytesArray(REGRESSION_DEFINITION.getBytes(StandardCharsets.UTF_8))); - return modelDocString(compressed, "test_regression"); + private TrainedModelConfig buildRegressionModel() throws IOException { + try (XContentParser parser = XContentHelper.createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(REGRESSION_CONFIG), + XContentType.JSON)) { + return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); + } } - private static String modelDocString(String compressedDefinition, String modelId) { - return "" + - "{" + - "\"model_id\": \"" + modelId + "\",\n" + - "\"doc_num\": 0,\n" + - "\"doc_type\": \"trained_model_definition_doc\",\n" + - " \"compression_version\": " + 1 + ",\n" + - " \"total_definition_length\": " + compressedDefinition.length() + ",\n" + - " \"definition_length\": " + compressedDefinition.length() + ",\n" + - "\"definition\": \"" + compressedDefinition + "\"\n" + - "}"; + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); } private static final String CLASSIFICATION_CONFIG = "" + @@ -549,8 +541,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for classification\",\n" + " \"version\": \"7.6.0\",\n" + + " \"definition\": " + CLASSIFICATION_DEFINITION + ","+ " \"license_level\": \"platinum\",\n" + - " \"created_by\": \"benwtrent\",\n" + + " \"created_by\": \"es_test\",\n" + " \"estimated_heap_memory_usage_bytes\": 0," + " \"estimated_operations\": 0," + " \"created_time\": 0\n" + diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 72c677d1aa4..0aec6bc3374 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -6,10 +6,18 @@ package org.elasticsearch.xpack.ml.integration; import org.apache.http.util.EntityUtils; -import org.elasticsearch.Version; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.client.ml.inference.TrainedModelDefinition; +import org.elasticsearch.client.ml.inference.TrainedModelInput; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -18,26 +26,19 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.license.License; import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.rest.ESRestTestCase; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; -import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; import org.junit.After; import java.io.IOException; -import java.time.Instant; import java.util.Arrays; import java.util.Collections; +import java.util.List; import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; import static org.hamcrest.Matchers.containsString; @@ -62,22 +63,8 @@ public class TrainedModelIT extends ESRestTestCase { public void testGetTrainedModels() throws IOException { String modelId = "a_test_regression_model"; String modelId2 = "a_test_regression_model-2"; - Request model1 = new Request("PUT", - InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); - model1.setJsonEntity(buildRegressionModel(modelId)); - assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); - - Request modelDefinition1 = new Request("PUT", - InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0)); - modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId)); - assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); - - Request model2 = new Request("PUT", - InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2); - model2.setJsonEntity(buildRegressionModel(modelId2)); - assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201)); - - adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + putRegressionModel(modelId); + putRegressionModel(modelId2); Response getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/" + modelId)); @@ -164,17 +151,7 @@ public class TrainedModelIT extends ESRestTestCase { public void testDeleteTrainedModels() throws IOException { String modelId = "test_delete_regression_model"; - Request model1 = new Request("PUT", - InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); - model1.setJsonEntity(buildRegressionModel(modelId)); - assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); - - Request modelDefinition1 = new Request("PUT", - InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0)); - modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId)); - assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); - - adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + putRegressionModel(modelId); Response delModel = client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId)); @@ -208,42 +185,68 @@ public class TrainedModelIT extends ESRestTestCase { assertThat(response, containsString("\"definition\"")); } - private static String buildRegressionModel(String modelId) throws IOException { + private void putRegressionModel(String modelId) throws IOException { try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder() + .setPreProcessors(Collections.emptyList()) + .setTrainedModel(buildRegression()); TrainedModelConfig.builder() + .setDefinition(definition) .setModelId(modelId) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3"))) - .setCreatedBy("ml_test") - .setVersion(Version.CURRENT) - .setCreateTime(Instant.now()) - .setEstimatedOperations(0) - .setLicenseLevel(License.OperationMode.PLATINUM.description()) - .setEstimatedHeapMemory(0) - .build() - .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); - return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + .build().toXContent(builder, ToXContent.EMPTY_PARAMS); + Request model = new Request("PUT", "_ml/inference/" + modelId); + model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON)); + assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); } } - private static String buildRegressionModelDefinitionDoc(String modelId) throws IOException { - try(XContentBuilder builder = XContentFactory.jsonBuilder()) { - TrainedModelDefinition definition = new TrainedModelDefinition.Builder() - .setPreProcessors(Collections.emptyList()) - .setTrainedModel(LocalModelTests.buildRegression()) - .build(); - String compressedString = InferenceToXContentCompressor.deflate(definition); - TrainedModelDefinitionDoc doc = new TrainedModelDefinitionDoc.Builder().setDocNum(0) - .setCompressedString(compressedString) - .setTotalDefinitionLength(compressedString.length()) - .setDefinitionLength(compressedString.length()) - .setCompressionVersion(1) - .setModelId(modelId).build(); - doc.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); - return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); - } + private static TrainedModel buildRegression() { + List featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5), + TreeNode.builder(1).setLeafValue(0.3), + TreeNode.builder(2) + .setThreshold(0.0) + .setSplitFeature(3) + .setLeftChild(3) + .setRightChild(4), + TreeNode.builder(3).setLeafValue(0.1), + TreeNode.builder(4).setLeafValue(0.2)) + .build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(2) + .setThreshold(1.0), + TreeNode.builder(1).setLeafValue(1.5), + TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(0.2), + TreeNode.builder(1).setLeafValue(1.5), + TreeNode.builder(2).setLeafValue(0.9)) + .build(); + return Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5))) + .build(); } - @After public void clearMlState() throws Exception { new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 1a12dad3114..1ce8566cf8c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -112,6 +112,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; @@ -183,6 +184,7 @@ import org.elasticsearch.xpack.ml.action.TransportPutDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportPutDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportPutFilterAction; import org.elasticsearch.xpack.ml.action.TransportPutJobAction; +import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction; import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; @@ -273,6 +275,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; +import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -773,7 +776,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu new RestExplainDataFrameAnalyticsAction(restController), new RestGetTrainedModelsAction(restController), new RestDeleteTrainedModelAction(restController), - new RestGetTrainedModelsStatsAction(restController) + new RestGetTrainedModelsStatsAction(restController), + new RestPutTrainedModelAction(restController) ); } @@ -844,7 +848,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu new ActionHandler<>(InternalInferModelAction.INSTANCE, TransportInternalInferModelAction.class), new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class), new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), - new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class) + new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class), + new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java new file mode 100644 index 00000000000..a520f621672 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class TransportPutTrainedModelAction extends TransportMasterNodeAction { + + private final TrainedModelProvider trainedModelProvider; + private final XPackLicenseState licenseState; + private final NamedXContentRegistry xContentRegistry; + private final Client client; + + @Inject + public TransportPutTrainedModelAction(TransportService transportService, ClusterService clusterService, + ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, Client client, + TrainedModelProvider trainedModelProvider, NamedXContentRegistry xContentRegistry) { + super(PutTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, Request::new, + indexNameExpressionResolver); + this.licenseState = licenseState; + this.trainedModelProvider = trainedModelProvider; + this.xContentRegistry = xContentRegistry; + this.client = client; + } + + @Override + protected String executor() { + return ThreadPool.Names.SAME; + } + + @Override + protected Response read(StreamInput in) throws IOException { + return new Response(in); + } + + @Override + protected void masterOperation(Request request, ClusterState state, ActionListener listener) { + try { + request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry); + request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate(); + } catch (IOException ex) { + listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]", + ex, + request.getTrainedModelConfig().getModelId())); + return; + } catch (ElasticsearchException ex) { + listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", + ex, + request.getTrainedModelConfig().getModelId())); + return; + } + + TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig()) + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .setCreatedBy("api_user") + .setLicenseLevel(License.OperationMode.PLATINUM.description()) + .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed()) + .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations()) + .build(); + + ActionListener tagsModelIdCheckListener = ActionListener.wrap( + r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( + storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)), + listener::onFailure + )), + listener::onFailure + ); + + ActionListener modelIdTagCheckListener = ActionListener.wrap( + r -> checkTagsAgainstModelIds(request.getTrainedModelConfig().getTags(), tagsModelIdCheckListener), + listener::onFailure + ); + + checkModelIdAgainstTags(request.getTrainedModelConfig().getModelId(), modelIdTagCheckListener); + } + + private void checkModelIdAgainstTags(String modelId, ActionListener listener) { + QueryBuilder builder = QueryBuilders.constantScoreQuery( + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId))); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1); + SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ML_ORIGIN, + searchRequest, + ActionListener.wrap( + response -> { + if (response.getHits().getTotalHits().value > 0) { + listener.onFailure( + ExceptionsHelper.badRequestException( + Messages.getMessage(Messages.INFERENCE_MODEL_ID_AND_TAGS_UNIQUE, modelId))); + return; + } + listener.onResponse(null); + }, + listener::onFailure + ), + client::search); + } + + private void checkTagsAgainstModelIds(List tags, ActionListener listener) { + if (tags.isEmpty()) { + listener.onResponse(null); + return; + } + + QueryBuilder builder = QueryBuilders.constantScoreQuery( + QueryBuilders.boolQuery() + .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags))); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1); + SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ML_ORIGIN, + searchRequest, + ActionListener.wrap( + response -> { + if (response.getHits().getTotalHits().value > 0) { + listener.onFailure( + ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE, tags))); + return; + } + listener.onResponse(null); + }, + listener::onFailure + ), + client::search); + } + + @Override + protected ClusterBlockException checkBlock(Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + if (licenseState.isMachineLearningAllowed()) { + super.doExecute(task, request, listener); + } else { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 2c14fd70f10..d63dbf1bc4b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -175,10 +175,12 @@ public class TrainedModelProvider { r -> { assert r.getItems().length == 2; if (r.getItems()[0].isFailed()) { + logger.error(new ParameterizedMessage( "[{}] failed to store trained model config for inference", trainedModelConfig.getModelId()), r.getItems()[0].getFailure().getCause()); + wrappedListener.onFailure(r.getItems()[0].getFailure().getCause()); return; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java new file mode 100644 index 00000000000..cb3f4e0edde --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +public class RestPutTrainedModelAction extends BaseRestHandler { + + public RestPutTrainedModelAction(RestController controller) { + controller.registerHandler(RestRequest.Method.PUT, + MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", + this); + } + + @Override + public String getName() { + return "xpack_ml_put_trained_model_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + XContentParser parser = restRequest.contentParser(); + PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser); + putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout())); + + return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java index 2e6d6b21344..c2042811e8b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.action.ingest.SimulatePipelineAction; import org.elasticsearch.action.ingest.SimulatePipelineRequest; import org.elasticsearch.action.ingest.SimulatePipelineResponse; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.cluster.ClusterState; @@ -37,26 +36,30 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.client.MachineLearningClient; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; -import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; -import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import org.junit.Before; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Collections; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -561,12 +564,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { " \"target_field\": \"regression_value\",\n" + " \"model_id\": \"modelprocessorlicensetest\",\n" + " \"inference_config\": {\"regression\": {}},\n" + - " \"field_mappings\": {\n" + - " \"col1\": \"col1\",\n" + - " \"col2\": \"col2\",\n" + - " \"col3\": \"col3\",\n" + - " \"col4\": \"col4\"\n" + - " }\n" + + " \"field_mappings\": {}\n" + " }\n" + " }]}\n"; // Creating a pipeline should work @@ -748,76 +746,22 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); } - private void putInferenceModel(String modelId) throws Exception { - String config = "" + - "{\n" + - " \"model_id\": \"" + modelId + "\",\n" + - " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + - " \"description\": \"test model for classification\",\n" + - " \"version\": \"7.6.0\",\n" + - " \"created_by\": \"benwtrent\",\n" + - " \"license_level\": \"platinum\",\n" + - " \"estimated_heap_memory_usage_bytes\": 0,\n" + - " \"estimated_operations\": 0,\n" + - " \"created_time\": 0\n" + - "}"; - String definition = "" + - "{" + - " \"trained_model\": {\n" + - " \"tree\": {\n" + - " \"feature_names\": [\n" + - " \"col1_male\",\n" + - " \"col1_female\",\n" + - " \"col2_encoded\",\n" + - " \"col3_encoded\",\n" + - " \"col4\"\n" + - " ],\n" + - " \"tree_structure\": [\n" + - " {\n" + - " \"node_index\": 0,\n" + - " \"split_feature\": 0,\n" + - " \"split_gain\": 12.0,\n" + - " \"threshold\": 10.0,\n" + - " \"decision_type\": \"lte\",\n" + - " \"default_left\": true,\n" + - " \"left_child\": 1,\n" + - " \"right_child\": 2\n" + - " },\n" + - " {\n" + - " \"node_index\": 1,\n" + - " \"leaf_value\": 1\n" + - " },\n" + - " {\n" + - " \"node_index\": 2,\n" + - " \"leaf_value\": 2\n" + - " }\n" + - " ],\n" + - " \"target_type\": \"regression\"\n" + - " }\n" + - " }" + - "}"; - String compressedDefinitionString = - InferenceToXContentCompressor.deflate(new BytesArray(definition.getBytes(StandardCharsets.UTF_8))); - String compressedDefinition = "" + - "{" + - " \"model_id\": \"" + modelId + "\",\n" + - " \"doc_type\": \"" + TrainedModelDefinitionDoc.NAME + "\",\n" + - " \"doc_num\": " + 0 + ",\n" + - " \"compression_version\": " + 1 + ",\n" + - " \"total_definition_length\": " + compressedDefinitionString.length() + ",\n" + - " \"definition_length\": " + compressedDefinitionString.length() + ",\n" + - " \"definition\": \"" + compressedDefinitionString + "\"\n" + - "}"; - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) - .setId(modelId) - .setSource(config, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get().status(), equalTo(RestStatus.CREATED)); - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) - .setId(TrainedModelDefinitionDoc.docId(modelId, 0)) - .setSource(compressedDefinition, XContentType.JSON) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .get().status(), equalTo(RestStatus.CREATED)); + private void putInferenceModel(String modelId) { + TrainedModelConfig config = TrainedModelConfig.builder() + .setParsedDefinition( + new TrainedModelDefinition.Builder() + .setTrainedModel( + Tree.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(Arrays.asList("feature1")) + .setNodes(TreeNode.builder(0).setLeafValue(1.0)) + .build()) + .setPreProcessors(Collections.emptyList())) + .setModelId(modelId) + .setDescription("test model for classification") + .setInput(new TrainedModelInput(Arrays.asList("feature1"))) + .build(); + client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet(); } private static OperationMode randomInvalidLicenseType() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index b4c7b447c53..b518e01f99d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -200,7 +200,6 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { return TrainedModelConfig.builder() .setCreatedBy("ml_test") .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder()) - .setDescription("trained model config for test") .setModelId(modelId) .setVersion(Version.CURRENT) diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json new file mode 100644 index 00000000000..a58fa135407 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json @@ -0,0 +1,28 @@ +{ + "ml.put_trained_model":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}", + "methods":[ + "PUT" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models to store" + } + } + } + ] + }, + "body": { + "description":"The trained model configuration", + "required":true + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index f72fd1120d8..f5f9a56bab8 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -1,3 +1,74 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: a-regression-model-0 + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "regression" + } + } + } + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: a-regression-model-1 + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "regression" + } + } + } + } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: a-classification-model + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "classification", + "classification_labels": ["no", "yes"] + } + } + } + } --- "Test get given missing trained model": @@ -24,56 +95,52 @@ - match: { count: 0 } - match: { trained_model_configs: [] } --- +"Test get models": + - do: + ml.get_trained_models: + model_id: "*" + - match: { count: 4 } + - match: { trained_model_configs.0.model_id: "a-classification-model" } + - match: { trained_model_configs.1.model_id: "a-regression-model-0" } + - match: { trained_model_configs.2.model_id: "a-regression-model-1" } + + - do: + ml.get_trained_models: + model_id: "a-regression*" + - match: { count: 2 } + - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + - match: { trained_model_configs.1.model_id: "a-regression-model-1" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 0 + size: 2 + - match: { count: 4 } + - match: { trained_model_configs.0.model_id: "a-classification-model" } + - match: { trained_model_configs.1.model_id: "a-regression-model-0" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 1 + size: 1 + - match: { count: 4 } + - match: { trained_model_configs.0.model_id: "a-regression-model-0" } +--- "Test delete given unused trained model": - - - do: - index: - id: trained_model_config-unused-regression-model-0 - index: .ml-inference-000001 - body: > - { - "model_id": "unused-regression-model", - "created_by": "ml_tests", - "version": "8.0.0", - "description": "empty model for tests", - "create_time": 0, - "model_version": 0, - "model_type": "local" - } - - do: - indices.refresh: {} - - do: ml.delete_trained_model: - model_id: "unused-regression-model" + model_id: "a-classification-model" - match: { acknowledged: true } - --- "Test delete with missing model": - do: catch: missing ml.delete_trained_model: model_id: "missing-trained-model" - --- "Test delete given used trained model": - - do: - index: - id: trained_model_config-used-regression-model-0 - index: .ml-inference-000001 - body: > - { - "model_id": "used-regression-model", - "created_by": "ml_tests", - "version": "8.0.0", - "description": "empty model for tests", - "create_time": 0, - "model_version": 0, - "model_type": "local" - } - - do: - indices.refresh: {} - - do: ingest.put_pipeline: id: "regression-model-pipeline" @@ -82,7 +149,7 @@ "processors": [ { "inference" : { - "model_id" : "used-regression-model", + "model_id" : "a-regression-model-0", "inference_config": {"regression": {}}, "target_field": "regression_field", "field_mappings": {} @@ -95,12 +162,12 @@ - do: catch: conflict ml.delete_trained_model: - model_id: "used-regression-model" + model_id: "a-regression-model-0" --- "Test get pre-packaged trained models": - do: ml.get_trained_models: - model_id: "_all" + model_id: "lang_ident_model_1" allow_no_match: false - match: { count: 1 } - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }