diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 9df33b410a7..2fc23acd134 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -58,6 +58,7 @@ import org.elasticsearch.client.ml.GetJobStatsRequest; import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetOverallBucketsRequest; import org.elasticsearch.client.ml.GetRecordsRequest; +import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.MlInfoRequest; import org.elasticsearch.client.ml.OpenJobRequest; import org.elasticsearch.client.ml.PostCalendarEventRequest; @@ -709,6 +710,38 @@ final class MLRequestConverters { return request; } + static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "inference") + .addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIds())) + .build(); + RequestConverters.Params params = new RequestConverters.Params(); + if (getTrainedModelsRequest.getPageParams() != null) { + PageParams pageParams = getTrainedModelsRequest.getPageParams(); + if (pageParams.getFrom() != null) { + params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != null) { + params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); + } + } + if (getTrainedModelsRequest.getAllowNoMatch() != null) { + params.putParam(GetTrainedModelsRequest.ALLOW_NO_MATCH, + Boolean.toString(getTrainedModelsRequest.getAllowNoMatch())); + } + if (getTrainedModelsRequest.getDecompressDefinition() != null) { + params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION, + Boolean.toString(getTrainedModelsRequest.getDecompressDefinition())); + } + if (getTrainedModelsRequest.getIncludeDefinition() != null) { + params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION, + Boolean.toString(getTrainedModelsRequest.getIncludeDefinition())); + } + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + request.addParameters(params.asMap()); + return request; + } + static Request putFilter(PutFilterRequest putFilterRequest) throws IOException { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 62619303685..2ddc8839f96 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -73,6 +73,8 @@ import org.elasticsearch.client.ml.GetOverallBucketsRequest; import org.elasticsearch.client.ml.GetOverallBucketsResponse; import org.elasticsearch.client.ml.GetRecordsRequest; import org.elasticsearch.client.ml.GetRecordsResponse; +import org.elasticsearch.client.ml.GetTrainedModelsRequest; +import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.MlInfoRequest; import org.elasticsearch.client.ml.MlInfoResponse; import org.elasticsearch.client.ml.OpenJobRequest; @@ -2290,4 +2292,48 @@ public final class MachineLearningClient { listener, Collections.emptySet()); } + + /** + * Gets trained model configs + *

+ * For additional info + * see + * GET Trained Model Configs documentation + * + * @param request The {@link GetTrainedModelsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetTrainedModelsResponse} response object + */ + public GetTrainedModelsResponse getTrainedModels(GetTrainedModelsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getTrainedModels, + options, + GetTrainedModelsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets trained model configs asynchronously and notifies listener upon completion + *

+ * For additional info + * see + * GET Trained Model Configs documentation + * + * @param request The {@link GetTrainedModelsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + * @return cancellable that may be used to cancel the request + */ + public Cancellable getTrainedModelsAsync(GetTrainedModelsRequest request, + RequestOptions options, + ActionListener listener) { + return restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getTrainedModels, + options, + GetTrainedModelsResponse::fromXContent, + listener, + Collections.emptySet()); + } + } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java new file mode 100644 index 00000000000..9234770a97a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java @@ -0,0 +1,139 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.common.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +public class GetTrainedModelsRequest implements Validatable { + + public static final String ALLOW_NO_MATCH = "allow_no_match"; + public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; + public static final String DECOMPRESS_DEFINITION = "decompress_definition"; + + private final List ids; + private Boolean allowNoMatch; + private Boolean includeDefinition; + private Boolean decompressDefinition; + private PageParams pageParams; + + /** + * Helper method to create a request that will get ALL TrainedModelConfigs + * @return new {@link GetTrainedModelsRequest} object for the id "_all" + */ + public static GetTrainedModelsRequest getAllTrainedModelConfigsRequest() { + return new GetTrainedModelsRequest("_all"); + } + + public GetTrainedModelsRequest(String... ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no trained models. + * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any trained models + */ + public GetTrainedModelsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + public PageParams getPageParams() { + return pageParams; + } + + public GetTrainedModelsRequest setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + return this; + } + + public Boolean getIncludeDefinition() { + return includeDefinition; + } + + /** + * Whether to include the full model definition. + * + * The full model definition can be very large. + * + * @param includeDefinition If {@code true}, the definition is included. + */ + public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) { + this.includeDefinition = includeDefinition; + return this; + } + + public Boolean getDecompressDefinition() { + return decompressDefinition; + } + + /** + * Whether or not to decompress the trained model, or keep it in its compressed string form + * + * @param decompressDefinition If {@code true}, the definition is decompressed. + */ + public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinition) { + this.decompressDefinition = decompressDefinition; + return this; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("trained model id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetTrainedModelsRequest other = (GetTrainedModelsRequest) o; + return Objects.equals(ids, other.ids) + && Objects.equals(allowNoMatch, other.allowNoMatch) + && Objects.equals(decompressDefinition, other.decompressDefinition) + && Objects.equals(includeDefinition, other.includeDefinition) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsResponse.java new file mode 100644 index 00000000000..c83bcd97f77 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsResponse.java @@ -0,0 +1,86 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class GetTrainedModelsResponse { + + public static final ParseField TRAINED_MODEL_CONFIGS = new ParseField("trained_model_configs"); + public static final ParseField COUNT = new ParseField("count"); + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "get_trained_model_configs", + true, + args -> new GetTrainedModelsResponse((List) args[0], (Long) args[1])); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelConfig.fromXContent(p), TRAINED_MODEL_CONFIGS); + PARSER.declareLong(constructorArg(), COUNT); + } + + public static GetTrainedModelsResponse fromXContent(final XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List trainedModels; + private final Long count; + + + public GetTrainedModelsResponse(List trainedModels, Long count) { + this.trainedModels = trainedModels; + this.count = count; + } + + public List getTrainedModels() { + return trainedModels; + } + + /** + * @return The total count of the trained models that matched the ID pattern. + */ + public Long getCount() { + return count; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetTrainedModelsResponse other = (GetTrainedModelsResponse) o; + return Objects.equals(this.trainedModels, other.trainedModels) && Objects.equals(this.count, other.count); + } + + @Override + public int hashCode() { + return Objects.hash(trainedModels, count); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 50775cde8a9..23eb01fb3b1 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -77,8 +77,8 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL); } - public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); + public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null).build(); } private final String modelId; @@ -293,12 +293,12 @@ public class TrainedModelConfig implements ToXContentObject { return this; } - private Builder setCreatedBy(String createdBy) { + public Builder setCreatedBy(String createdBy) { this.createdBy = createdBy; return this; } - private Builder setVersion(Version version) { + public Builder setVersion(Version version) { this.version = version; return this; } @@ -312,7 +312,7 @@ public class TrainedModelConfig implements ToXContentObject { return this; } - private Builder setCreateTime(Instant createTime) { + public Builder setCreateTime(Instant createTime) { this.createTime = createTime; return this; } @@ -347,17 +347,17 @@ public class TrainedModelConfig implements ToXContentObject { return this; } - private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { + public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { this.estimatedHeapMemory = estimatedHeapMemory; return this; } - private Builder setEstimatedOperations(Long estimatedOperations) { + public Builder setEstimatedOperations(Long estimatedOperations) { this.estimatedOperations = estimatedOperations; return this; } - private Builder setLicenseLevel(String licenseLevel) { + public Builder setLicenseLevel(String licenseLevel) { this.licenseLevel = licenseLevel; return this; } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index e0fac7bb09a..db59054cdb8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.client.ml.GetJobStatsRequest; import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetOverallBucketsRequest; import org.elasticsearch.client.ml.GetRecordsRequest; +import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.MlInfoRequest; import org.elasticsearch.client.ml.OpenJobRequest; import org.elasticsearch.client.ml.PostCalendarEventRequest; @@ -798,6 +799,31 @@ public class MLRequestConvertersTests extends ESTestCase { } } + public void testGetTrainedModels() { + String modelId1 = randomAlphaOfLength(10); + String modelId2 = randomAlphaOfLength(10); + String modelId3 = randomAlphaOfLength(10); + GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3) + .setAllowNoMatch(false) + .setDecompressDefinition(true) + .setIncludeDefinition(false) + .setPageParams(new PageParams(100, 300)); + + Request request = MLRequestConverters.getTrainedModels(getRequest); + assertEquals(HttpGet.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/inference/" + modelId1 + "," + modelId2 + "," + modelId3, request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false"))); + assertThat(request.getParameters(), + allOf( + hasEntry("from", "100"), + hasEntry("size", "300"), + hasEntry("allow_no_match", "false"), + hasEntry("decompress_definition", "true"), + hasEntry("include_model_definition", "false") + )); + assertNull(request.getEntity()); + } + public void testPutFilter() throws IOException { MlFilter filter = MlFilterTests.createRandomBuilder("foo").build(); PutFilterRequest putFilterRequest = new PutFilterRequest(filter); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 910d091c8a0..361b3674550 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -73,6 +73,8 @@ import org.elasticsearch.client.ml.GetJobStatsRequest; import org.elasticsearch.client.ml.GetJobStatsResponse; import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetModelSnapshotsResponse; +import org.elasticsearch.client.ml.GetTrainedModelsRequest; +import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.MlInfoRequest; import org.elasticsearch.client.ml.MlInfoResponse; import org.elasticsearch.client.ml.OpenJobRequest; @@ -139,6 +141,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Confu import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.client.ml.inference.TrainedModelDefinition; +import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.DataDescription; import org.elasticsearch.client.ml.job.config.Detector; @@ -148,11 +153,14 @@ import org.elasticsearch.client.ml.job.config.JobUpdate; import org.elasticsearch.client.ml.job.config.MlFilter; import org.elasticsearch.client.ml.job.process.ModelSnapshot; import org.elasticsearch.client.ml.job.stats.JobStats; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -161,9 +169,11 @@ import org.elasticsearch.search.SearchHit; import org.junit.After; import java.io.IOException; +import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -171,6 +181,7 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import java.util.zip.GZIPOutputStream; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; @@ -186,6 +197,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; public class MachineLearningIT extends ESRestHighLevelClientTestCase { @@ -2032,6 +2044,75 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { allOf(greaterThanOrEqualTo(response1.getExpectedMemoryWithDisk()), lessThan(upperBound))); } + public void testGetTrainedModels() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String modelIdPrefix = "get-trained-model-"; + int numberOfModels = 5; + for (int i = 0; i < numberOfModels; ++i) { + String modelId = modelIdPrefix + i; + putTrainedModel(modelId); + } + + { + GetTrainedModelsResponse getTrainedModelsResponse = execute( + new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true), + machineLearningClient::getTrainedModels, + machineLearningClient::getTrainedModelsAsync); + + assertThat(getTrainedModelsResponse.getCount(), equalTo(1L)); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1)); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue())); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue()))); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0)); + + getTrainedModelsResponse = execute( + new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true), + machineLearningClient::getTrainedModels, + machineLearningClient::getTrainedModelsAsync); + + assertThat(getTrainedModelsResponse.getCount(), equalTo(1L)); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1)); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(not(nullValue()))); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(nullValue())); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0)); + + getTrainedModelsResponse = execute( + new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false), + machineLearningClient::getTrainedModels, + machineLearningClient::getTrainedModelsAsync); + assertThat(getTrainedModelsResponse.getCount(), equalTo(1L)); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1)); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue())); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(nullValue())); + assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0)); + + } + { + GetTrainedModelsResponse getTrainedModelsResponse = execute( + GetTrainedModelsRequest.getAllTrainedModelConfigsRequest(), + machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels)); + assertThat(getTrainedModelsResponse.getCount(), equalTo(5L)); + } + { + GetTrainedModelsResponse getTrainedModelsResponse = execute( + new GetTrainedModelsRequest(modelIdPrefix + 4, modelIdPrefix + 2, modelIdPrefix + 3), + machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(3)); + assertThat(getTrainedModelsResponse.getCount(), equalTo(3L)); + } + { + GetTrainedModelsResponse getTrainedModelsResponse = execute( + new GetTrainedModelsRequest(modelIdPrefix + "*").setPageParams(new PageParams(1, 2)), + machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(2)); + assertThat(getTrainedModelsResponse.getCount(), equalTo(5L)); + assertThat( + getTrainedModelsResponse.getTrainedModels().stream().map(TrainedModelConfig::getModelId).collect(Collectors.toList()), + containsInAnyOrder(modelIdPrefix + 1, modelIdPrefix + 2)); + } + } + public void testPutFilter() throws Exception { String filterId = "filter-job-test"; MlFilter mlFilter = MlFilter.builder(filterId) @@ -2209,6 +2290,60 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { highLevelClient().machineLearning().openJob(new OpenJobRequest(job.getId()), RequestOptions.DEFAULT); } + private void putTrainedModel(String modelId) throws IOException { + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); + highLevelClient().index( + new IndexRequest(".ml-inference-000001") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(modelConfigString(modelId), XContentType.JSON) + .id(modelId), + RequestOptions.DEFAULT); + + highLevelClient().index( + new IndexRequest(".ml-inference-000001") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON) + .id("trained_model_definition_doc-" + modelId + "-0"), + RequestOptions.DEFAULT); + } + + private String compressDefinition(TrainedModelDefinition definition) throws IOException { + BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false); + BytesStreamOutput out = new BytesStreamOutput(); + try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) { + reference.writeTo(compressedOutput); + } + return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); + } + + private static String modelConfigString(String modelId) { + return "{\n" + + " \"doc_type\": \"trained_model_config\",\n" + + " \"model_id\": \"" + modelId + "\",\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"description\": \"test model\",\n" + + " \"version\": \"7.6.0\",\n" + + " \"license_level\": \"platinum\",\n" + + " \"created_by\": \"ml_test\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + + " \"created_time\": 0\n" + + "}"; + } + + private static String modelDocString(String compressedDefinition, String modelId) { + return "" + + "{" + + "\"model_id\": \"" + modelId + "\",\n" + + "\"doc_num\": 0,\n" + + "\"doc_type\": \"trained_model_definition_doc\",\n" + + " \"compression_version\": " + 1 + ",\n" + + " \"total_definition_length\": " + compressedDefinition.length() + ",\n" + + " \"definition_length\": " + compressedDefinition.length() + ",\n" + + "\"definition\": \"" + compressedDefinition + "\"\n" + + "}"; + } + private void waitForJobToClose(String jobId) throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 63a397eb0c0..da12420535f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -87,6 +87,8 @@ import org.elasticsearch.client.ml.GetOverallBucketsRequest; import org.elasticsearch.client.ml.GetOverallBucketsResponse; import org.elasticsearch.client.ml.GetRecordsRequest; import org.elasticsearch.client.ml.GetRecordsResponse; +import org.elasticsearch.client.ml.GetTrainedModelsRequest; +import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.MlInfoRequest; import org.elasticsearch.client.ml.MlInfoResponse; import org.elasticsearch.client.ml.OpenJobRequest; @@ -154,6 +156,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Confu import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; +import org.elasticsearch.client.ml.inference.TrainedModelDefinition; +import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -174,10 +179,12 @@ import org.elasticsearch.client.ml.job.results.Influencer; import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -188,10 +195,12 @@ import org.elasticsearch.tasks.TaskId; import org.junit.After; import java.io.IOException; +import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.Date; import java.util.HashMap; @@ -200,6 +209,7 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import java.util.zip.GZIPOutputStream; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.closeTo; @@ -3516,6 +3526,58 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { } } + public void testGetTrainedModels() throws Exception { + putTrainedModel("my-trained-model"); + RestHighLevelClient client = highLevelClient(); + { + // tag::get-trained-models-request + GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1> + .setPageParams(new PageParams(0, 1)) // <2> + .setIncludeDefinition(false) // <3> + .setDecompressDefinition(false) // <4> + .setAllowNoMatch(true); // <5> + // end::get-trained-models-request + + // tag::get-trained-models-execute + GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT); + // end::get-trained-models-execute + + // tag::get-trained-models-response + List models = response.getTrainedModels(); + // end::get-trained-models-response + + assertThat(models, hasSize(1)); + } + { + GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model"); + + // tag::get-trained-models-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(GetTrainedModelsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-trained-models-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-trained-models-execute-async + client.machineLearning().getTrainedModelsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-trained-models-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testCreateFilter() throws Exception { RestHighLevelClient client = highLevelClient(); { @@ -3878,6 +3940,60 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { return stats.getState(); } + private void putTrainedModel(String modelId) throws IOException { + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); + highLevelClient().index( + new IndexRequest(".ml-inference-000001") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(modelConfigString(modelId), XContentType.JSON) + .id(modelId), + RequestOptions.DEFAULT); + + highLevelClient().index( + new IndexRequest(".ml-inference-000001") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON) + .id("trained_model_definition_doc-" + modelId + "-0"), + RequestOptions.DEFAULT); + } + + private String compressDefinition(TrainedModelDefinition definition) throws IOException { + BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false); + BytesStreamOutput out = new BytesStreamOutput(); + try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) { + reference.writeTo(compressedOutput); + } + return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); + } + + private static String modelConfigString(String modelId) { + return "{\n" + + " \"doc_type\": \"trained_model_config\",\n" + + " \"model_id\": \"" + modelId + "\",\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"description\": \"test model for\",\n" + + " \"version\": \"7.6.0\",\n" + + " \"license_level\": \"platinum\",\n" + + " \"created_by\": \"ml_test\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + + " \"created_time\": 0\n" + + "}"; + } + + private static String modelDocString(String compressedDefinition, String modelId) { + return "" + + "{" + + "\"model_id\": \"" + modelId + "\",\n" + + "\"doc_num\": 0,\n" + + "\"doc_type\": \"trained_model_definition_doc\",\n" + + " \"compression_version\": " + 1 + ",\n" + + " \"total_definition_length\": " + compressedDefinition.length() + ",\n" + + " \"definition_length\": " + compressedDefinition.length() + ",\n" + + "\"definition\": \"" + compressedDefinition + "\"\n" + + "}"; + } + private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG = DataFrameAnalyticsConfig.builder() .setId("my-analytics-config") diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsRequestTests.java new file mode 100644 index 00000000000..3f26a1ce3b5 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class GetTrainedModelsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetTrainedModelsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetTrainedModelsRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new GetTrainedModelsRequest(new String[0]).validate().get().getMessage(), + containsString("trained model id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index 6d489f32472..51f9692b8b6 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -39,7 +39,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase Constructing a new GET request referencing an existing Trained Model +<2> Set the paging parameters +<3> Indicate if the complete model definition should be included +<4> Should the definition be fully decompressed on GET +<5> Allow empty response if no Trained Models match the provided ID patterns. + If false, an error will be thrown if no Trained Models match the + ID patterns. + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested Trained Model. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc index 153a0cf577c..770866a0755 100644 --- a/docs/java-rest/high-level/supported-apis.asciidoc +++ b/docs/java-rest/high-level/supported-apis.asciidoc @@ -301,6 +301,7 @@ The Java High Level REST Client supports the following Machine Learning APIs: * <<{upid}-stop-data-frame-analytics>> * <<{upid}-evaluate-data-frame>> * <<{upid}-estimate-memory-usage>> +* <<{upid}-get-trained-models>> * <<{upid}-put-filter>> * <<{upid}-get-filters>> * <<{upid}-update-filter>> @@ -353,6 +354,7 @@ include::ml/start-data-frame-analytics.asciidoc[] include::ml/stop-data-frame-analytics.asciidoc[] include::ml/evaluate-data-frame.asciidoc[] include::ml/estimate-memory-usage.asciidoc[] +include::ml/get-trained-models.asciidoc[] include::ml/put-filter.asciidoc[] include::ml/get-filters.asciidoc[] include::ml/update-filter.asciidoc[]