diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
index 9df33b410a7..2fc23acd134 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
@@ -58,6 +58,7 @@ import org.elasticsearch.client.ml.GetJobStatsRequest;
import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
import org.elasticsearch.client.ml.GetRecordsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.MlInfoRequest;
import org.elasticsearch.client.ml.OpenJobRequest;
import org.elasticsearch.client.ml.PostCalendarEventRequest;
@@ -709,6 +710,38 @@ final class MLRequestConverters {
return request;
}
+ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest) {
+ String endpoint = new EndpointBuilder()
+ .addPathPartAsIs("_ml", "inference")
+ .addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIds()))
+ .build();
+ RequestConverters.Params params = new RequestConverters.Params();
+ if (getTrainedModelsRequest.getPageParams() != null) {
+ PageParams pageParams = getTrainedModelsRequest.getPageParams();
+ if (pageParams.getFrom() != null) {
+ params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString());
+ }
+ if (pageParams.getSize() != null) {
+ params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString());
+ }
+ }
+ if (getTrainedModelsRequest.getAllowNoMatch() != null) {
+ params.putParam(GetTrainedModelsRequest.ALLOW_NO_MATCH,
+ Boolean.toString(getTrainedModelsRequest.getAllowNoMatch()));
+ }
+ if (getTrainedModelsRequest.getDecompressDefinition() != null) {
+ params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
+ Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
+ }
+ if (getTrainedModelsRequest.getIncludeDefinition() != null) {
+ params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
+ Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
+ }
+ Request request = new Request(HttpGet.METHOD_NAME, endpoint);
+ request.addParameters(params.asMap());
+ return request;
+ }
+
static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml")
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
index 62619303685..2ddc8839f96 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
@@ -73,6 +73,8 @@ import org.elasticsearch.client.ml.GetOverallBucketsRequest;
import org.elasticsearch.client.ml.GetOverallBucketsResponse;
import org.elasticsearch.client.ml.GetRecordsRequest;
import org.elasticsearch.client.ml.GetRecordsResponse;
+import org.elasticsearch.client.ml.GetTrainedModelsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsResponse;
import org.elasticsearch.client.ml.MlInfoRequest;
import org.elasticsearch.client.ml.MlInfoResponse;
import org.elasticsearch.client.ml.OpenJobRequest;
@@ -2290,4 +2292,48 @@ public final class MachineLearningClient {
listener,
Collections.emptySet());
}
+
+ /**
+ * Gets trained model configs
+ *
+ * For additional info
+ * see
+ * GET Trained Model Configs documentation
+ *
+ * @param request The {@link GetTrainedModelsRequest}
+ * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+ * @return {@link GetTrainedModelsResponse} response object
+ */
+ public GetTrainedModelsResponse getTrainedModels(GetTrainedModelsRequest request,
+ RequestOptions options) throws IOException {
+ return restHighLevelClient.performRequestAndParseEntity(request,
+ MLRequestConverters::getTrainedModels,
+ options,
+ GetTrainedModelsResponse::fromXContent,
+ Collections.emptySet());
+ }
+
+ /**
+ * Gets trained model configs asynchronously and notifies listener upon completion
+ *
+ * For additional info
+ * see
+ * GET Trained Model Configs documentation
+ *
+ * @param request The {@link GetTrainedModelsRequest}
+ * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+ * @param listener Listener to be notified upon request completion
+ * @return cancellable that may be used to cancel the request
+ */
+ public Cancellable getTrainedModelsAsync(GetTrainedModelsRequest request,
+ RequestOptions options,
+ ActionListener listener) {
+ return restHighLevelClient.performRequestAsyncAndParseEntity(request,
+ MLRequestConverters::getTrainedModels,
+ options,
+ GetTrainedModelsResponse::fromXContent,
+ listener,
+ Collections.emptySet());
+ }
+
}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java
new file mode 100644
index 00000000000..9234770a97a
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.client.core.PageParams;
+import org.elasticsearch.common.Nullable;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+public class GetTrainedModelsRequest implements Validatable {
+
+ public static final String ALLOW_NO_MATCH = "allow_no_match";
+ public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
+ public static final String DECOMPRESS_DEFINITION = "decompress_definition";
+
+ private final List ids;
+ private Boolean allowNoMatch;
+ private Boolean includeDefinition;
+ private Boolean decompressDefinition;
+ private PageParams pageParams;
+
+ /**
+ * Helper method to create a request that will get ALL TrainedModelConfigs
+ * @return new {@link GetTrainedModelsRequest} object for the id "_all"
+ */
+ public static GetTrainedModelsRequest getAllTrainedModelConfigsRequest() {
+ return new GetTrainedModelsRequest("_all");
+ }
+
+ public GetTrainedModelsRequest(String... ids) {
+ this.ids = Arrays.asList(ids);
+ }
+
+ public List getIds() {
+ return ids;
+ }
+
+ public Boolean getAllowNoMatch() {
+ return allowNoMatch;
+ }
+
+ /**
+ * Whether to ignore if a wildcard expression matches no trained models.
+ *
+ * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all})
+ * does not match any trained models
+ */
+ public GetTrainedModelsRequest setAllowNoMatch(boolean allowNoMatch) {
+ this.allowNoMatch = allowNoMatch;
+ return this;
+ }
+
+ public PageParams getPageParams() {
+ return pageParams;
+ }
+
+ public GetTrainedModelsRequest setPageParams(@Nullable PageParams pageParams) {
+ this.pageParams = pageParams;
+ return this;
+ }
+
+ public Boolean getIncludeDefinition() {
+ return includeDefinition;
+ }
+
+ /**
+ * Whether to include the full model definition.
+ *
+ * The full model definition can be very large.
+ *
+ * @param includeDefinition If {@code true}, the definition is included.
+ */
+ public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
+ this.includeDefinition = includeDefinition;
+ return this;
+ }
+
+ public Boolean getDecompressDefinition() {
+ return decompressDefinition;
+ }
+
+ /**
+ * Whether or not to decompress the trained model, or keep it in its compressed string form
+ *
+ * @param decompressDefinition If {@code true}, the definition is decompressed.
+ */
+ public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinition) {
+ this.decompressDefinition = decompressDefinition;
+ return this;
+ }
+
+ @Override
+ public Optional validate() {
+ if (ids == null || ids.isEmpty()) {
+ return Optional.of(ValidationException.withError("trained model id must not be null"));
+ }
+ return Optional.empty();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ GetTrainedModelsRequest other = (GetTrainedModelsRequest) o;
+ return Objects.equals(ids, other.ids)
+ && Objects.equals(allowNoMatch, other.allowNoMatch)
+ && Objects.equals(decompressDefinition, other.decompressDefinition)
+ && Objects.equals(includeDefinition, other.includeDefinition)
+ && Objects.equals(pageParams, other.pageParams);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsResponse.java
new file mode 100644
index 00000000000..c83bcd97f77
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsResponse.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class GetTrainedModelsResponse {
+
+ public static final ParseField TRAINED_MODEL_CONFIGS = new ParseField("trained_model_configs");
+ public static final ParseField COUNT = new ParseField("count");
+
+ @SuppressWarnings("unchecked")
+ static final ConstructingObjectParser PARSER =
+ new ConstructingObjectParser<>(
+ "get_trained_model_configs",
+ true,
+ args -> new GetTrainedModelsResponse((List) args[0], (Long) args[1]));
+
+ static {
+ PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelConfig.fromXContent(p), TRAINED_MODEL_CONFIGS);
+ PARSER.declareLong(constructorArg(), COUNT);
+ }
+
+ public static GetTrainedModelsResponse fromXContent(final XContentParser parser) {
+ return PARSER.apply(parser, null);
+ }
+
+ private final List trainedModels;
+ private final Long count;
+
+
+ public GetTrainedModelsResponse(List trainedModels, Long count) {
+ this.trainedModels = trainedModels;
+ this.count = count;
+ }
+
+ public List getTrainedModels() {
+ return trainedModels;
+ }
+
+ /**
+ * @return The total count of the trained models that matched the ID pattern.
+ */
+ public Long getCount() {
+ return count;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ GetTrainedModelsResponse other = (GetTrainedModelsResponse) o;
+ return Objects.equals(this.trainedModels, other.trainedModels) && Objects.equals(this.count, other.count);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(trainedModels, count);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
index 50775cde8a9..23eb01fb3b1 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
@@ -77,8 +77,8 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
}
- public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
- return PARSER.parse(parser, null);
+ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
+ return PARSER.parse(parser, null).build();
}
private final String modelId;
@@ -293,12 +293,12 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
- private Builder setCreatedBy(String createdBy) {
+ public Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
}
- private Builder setVersion(Version version) {
+ public Builder setVersion(Version version) {
this.version = version;
return this;
}
@@ -312,7 +312,7 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
- private Builder setCreateTime(Instant createTime) {
+ public Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
}
@@ -347,17 +347,17 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
- private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
+ public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
- private Builder setEstimatedOperations(Long estimatedOperations) {
+ public Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
- private Builder setLicenseLevel(String licenseLevel) {
+ public Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
index e0fac7bb09a..db59054cdb8 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
@@ -56,6 +56,7 @@ import org.elasticsearch.client.ml.GetJobStatsRequest;
import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
import org.elasticsearch.client.ml.GetRecordsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.MlInfoRequest;
import org.elasticsearch.client.ml.OpenJobRequest;
import org.elasticsearch.client.ml.PostCalendarEventRequest;
@@ -798,6 +799,31 @@ public class MLRequestConvertersTests extends ESTestCase {
}
}
+ public void testGetTrainedModels() {
+ String modelId1 = randomAlphaOfLength(10);
+ String modelId2 = randomAlphaOfLength(10);
+ String modelId3 = randomAlphaOfLength(10);
+ GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
+ .setAllowNoMatch(false)
+ .setDecompressDefinition(true)
+ .setIncludeDefinition(false)
+ .setPageParams(new PageParams(100, 300));
+
+ Request request = MLRequestConverters.getTrainedModels(getRequest);
+ assertEquals(HttpGet.METHOD_NAME, request.getMethod());
+ assertEquals("/_ml/inference/" + modelId1 + "," + modelId2 + "," + modelId3, request.getEndpoint());
+ assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false")));
+ assertThat(request.getParameters(),
+ allOf(
+ hasEntry("from", "100"),
+ hasEntry("size", "300"),
+ hasEntry("allow_no_match", "false"),
+ hasEntry("decompress_definition", "true"),
+ hasEntry("include_model_definition", "false")
+ ));
+ assertNull(request.getEntity());
+ }
+
public void testPutFilter() throws IOException {
MlFilter filter = MlFilterTests.createRandomBuilder("foo").build();
PutFilterRequest putFilterRequest = new PutFilterRequest(filter);
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
index 910d091c8a0..361b3674550 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
@@ -73,6 +73,8 @@ import org.elasticsearch.client.ml.GetJobStatsRequest;
import org.elasticsearch.client.ml.GetJobStatsResponse;
import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
import org.elasticsearch.client.ml.GetModelSnapshotsResponse;
+import org.elasticsearch.client.ml.GetTrainedModelsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsResponse;
import org.elasticsearch.client.ml.MlInfoRequest;
import org.elasticsearch.client.ml.MlInfoResponse;
import org.elasticsearch.client.ml.OpenJobRequest;
@@ -139,6 +141,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Confu
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.DataDescription;
import org.elasticsearch.client.ml.job.config.Detector;
@@ -148,11 +153,14 @@ import org.elasticsearch.client.ml.job.config.JobUpdate;
import org.elasticsearch.client.ml.job.config.MlFilter;
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
import org.elasticsearch.client.ml.job.stats.JobStats;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@@ -161,9 +169,11 @@ import org.elasticsearch.search.SearchHit;
import org.junit.After;
import java.io.IOException;
+import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -171,6 +181,7 @@ import java.util.Locale;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
+import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
@@ -186,6 +197,7 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
public class MachineLearningIT extends ESRestHighLevelClientTestCase {
@@ -2032,6 +2044,75 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
allOf(greaterThanOrEqualTo(response1.getExpectedMemoryWithDisk()), lessThan(upperBound)));
}
+ public void testGetTrainedModels() throws Exception {
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+ String modelIdPrefix = "get-trained-model-";
+ int numberOfModels = 5;
+ for (int i = 0; i < numberOfModels; ++i) {
+ String modelId = modelIdPrefix + i;
+ putTrainedModel(modelId);
+ }
+
+ {
+ GetTrainedModelsResponse getTrainedModelsResponse = execute(
+ new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true),
+ machineLearningClient::getTrainedModels,
+ machineLearningClient::getTrainedModelsAsync);
+
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue()));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue())));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
+
+ getTrainedModelsResponse = execute(
+ new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true),
+ machineLearningClient::getTrainedModels,
+ machineLearningClient::getTrainedModelsAsync);
+
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(not(nullValue())));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(nullValue()));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
+
+ getTrainedModelsResponse = execute(
+ new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false),
+ machineLearningClient::getTrainedModels,
+ machineLearningClient::getTrainedModelsAsync);
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue()));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(nullValue()));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
+
+ }
+ {
+ GetTrainedModelsResponse getTrainedModelsResponse = execute(
+ GetTrainedModelsRequest.getAllTrainedModelConfigsRequest(),
+ machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync);
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels));
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(5L));
+ }
+ {
+ GetTrainedModelsResponse getTrainedModelsResponse = execute(
+ new GetTrainedModelsRequest(modelIdPrefix + 4, modelIdPrefix + 2, modelIdPrefix + 3),
+ machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync);
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(3));
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(3L));
+ }
+ {
+ GetTrainedModelsResponse getTrainedModelsResponse = execute(
+ new GetTrainedModelsRequest(modelIdPrefix + "*").setPageParams(new PageParams(1, 2)),
+ machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync);
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(2));
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(5L));
+ assertThat(
+ getTrainedModelsResponse.getTrainedModels().stream().map(TrainedModelConfig::getModelId).collect(Collectors.toList()),
+ containsInAnyOrder(modelIdPrefix + 1, modelIdPrefix + 2));
+ }
+ }
+
public void testPutFilter() throws Exception {
String filterId = "filter-job-test";
MlFilter mlFilter = MlFilter.builder(filterId)
@@ -2209,6 +2290,60 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
highLevelClient().machineLearning().openJob(new OpenJobRequest(job.getId()), RequestOptions.DEFAULT);
}
+ private void putTrainedModel(String modelId) throws IOException {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+ highLevelClient().index(
+ new IndexRequest(".ml-inference-000001")
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+ .source(modelConfigString(modelId), XContentType.JSON)
+ .id(modelId),
+ RequestOptions.DEFAULT);
+
+ highLevelClient().index(
+ new IndexRequest(".ml-inference-000001")
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+ .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
+ .id("trained_model_definition_doc-" + modelId + "-0"),
+ RequestOptions.DEFAULT);
+ }
+
+ private String compressDefinition(TrainedModelDefinition definition) throws IOException {
+ BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
+ BytesStreamOutput out = new BytesStreamOutput();
+ try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
+ reference.writeTo(compressedOutput);
+ }
+ return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
+ }
+
+ private static String modelConfigString(String modelId) {
+ return "{\n" +
+ " \"doc_type\": \"trained_model_config\",\n" +
+ " \"model_id\": \"" + modelId + "\",\n" +
+ " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
+ " \"description\": \"test model\",\n" +
+ " \"version\": \"7.6.0\",\n" +
+ " \"license_level\": \"platinum\",\n" +
+ " \"created_by\": \"ml_test\",\n" +
+ " \"estimated_heap_memory_usage_bytes\": 0," +
+ " \"estimated_operations\": 0," +
+ " \"created_time\": 0\n" +
+ "}";
+ }
+
+ private static String modelDocString(String compressedDefinition, String modelId) {
+ return "" +
+ "{" +
+ "\"model_id\": \"" + modelId + "\",\n" +
+ "\"doc_num\": 0,\n" +
+ "\"doc_type\": \"trained_model_definition_doc\",\n" +
+ " \"compression_version\": " + 1 + ",\n" +
+ " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
+ " \"definition_length\": " + compressedDefinition.length() + ",\n" +
+ "\"definition\": \"" + compressedDefinition + "\"\n" +
+ "}";
+ }
+
private void waitForJobToClose(String jobId) throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
index 63a397eb0c0..da12420535f 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
@@ -87,6 +87,8 @@ import org.elasticsearch.client.ml.GetOverallBucketsRequest;
import org.elasticsearch.client.ml.GetOverallBucketsResponse;
import org.elasticsearch.client.ml.GetRecordsRequest;
import org.elasticsearch.client.ml.GetRecordsResponse;
+import org.elasticsearch.client.ml.GetTrainedModelsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsResponse;
import org.elasticsearch.client.ml.MlInfoRequest;
import org.elasticsearch.client.ml.MlInfoResponse;
import org.elasticsearch.client.ml.OpenJobRequest;
@@ -154,6 +156,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Confu
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
import org.elasticsearch.client.ml.job.config.DataDescription;
@@ -174,10 +179,12 @@ import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@@ -188,10 +195,12 @@ import org.elasticsearch.tasks.TaskId;
import org.junit.After;
import java.io.IOException;
+import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
+import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
@@ -200,6 +209,7 @@ import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
+import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
@@ -3516,6 +3526,58 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
}
}
+ public void testGetTrainedModels() throws Exception {
+ putTrainedModel("my-trained-model");
+ RestHighLevelClient client = highLevelClient();
+ {
+ // tag::get-trained-models-request
+ GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
+ .setPageParams(new PageParams(0, 1)) // <2>
+ .setIncludeDefinition(false) // <3>
+ .setDecompressDefinition(false) // <4>
+ .setAllowNoMatch(true); // <5>
+ // end::get-trained-models-request
+
+ // tag::get-trained-models-execute
+ GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);
+ // end::get-trained-models-execute
+
+ // tag::get-trained-models-response
+ List models = response.getTrainedModels();
+ // end::get-trained-models-response
+
+ assertThat(models, hasSize(1));
+ }
+ {
+ GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model");
+
+ // tag::get-trained-models-execute-listener
+ ActionListener listener = new ActionListener() {
+ @Override
+ public void onResponse(GetTrainedModelsResponse response) {
+ // <1>
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ // <2>
+ }
+ };
+ // end::get-trained-models-execute-listener
+
+ // Replace the empty listener by a blocking listener in test
+ CountDownLatch latch = new CountDownLatch(1);
+ listener = new LatchedActionListener<>(listener, latch);
+
+ // tag::get-trained-models-execute-async
+ client.machineLearning().getTrainedModelsAsync(request, RequestOptions.DEFAULT, listener); // <1>
+ // end::get-trained-models-execute-async
+
+ assertTrue(latch.await(30L, TimeUnit.SECONDS));
+ }
+ }
+
+
public void testCreateFilter() throws Exception {
RestHighLevelClient client = highLevelClient();
{
@@ -3878,6 +3940,60 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
return stats.getState();
}
+ private void putTrainedModel(String modelId) throws IOException {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+ highLevelClient().index(
+ new IndexRequest(".ml-inference-000001")
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+ .source(modelConfigString(modelId), XContentType.JSON)
+ .id(modelId),
+ RequestOptions.DEFAULT);
+
+ highLevelClient().index(
+ new IndexRequest(".ml-inference-000001")
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+ .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
+ .id("trained_model_definition_doc-" + modelId + "-0"),
+ RequestOptions.DEFAULT);
+ }
+
+ private String compressDefinition(TrainedModelDefinition definition) throws IOException {
+ BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
+ BytesStreamOutput out = new BytesStreamOutput();
+ try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
+ reference.writeTo(compressedOutput);
+ }
+ return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
+ }
+
+ private static String modelConfigString(String modelId) {
+ return "{\n" +
+ " \"doc_type\": \"trained_model_config\",\n" +
+ " \"model_id\": \"" + modelId + "\",\n" +
+ " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
+ " \"description\": \"test model for\",\n" +
+ " \"version\": \"7.6.0\",\n" +
+ " \"license_level\": \"platinum\",\n" +
+ " \"created_by\": \"ml_test\",\n" +
+ " \"estimated_heap_memory_usage_bytes\": 0," +
+ " \"estimated_operations\": 0," +
+ " \"created_time\": 0\n" +
+ "}";
+ }
+
+ private static String modelDocString(String compressedDefinition, String modelId) {
+ return "" +
+ "{" +
+ "\"model_id\": \"" + modelId + "\",\n" +
+ "\"doc_num\": 0,\n" +
+ "\"doc_type\": \"trained_model_definition_doc\",\n" +
+ " \"compression_version\": " + 1 + ",\n" +
+ " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
+ " \"definition_length\": " + compressedDefinition.length() + ",\n" +
+ "\"definition\": \"" + compressedDefinition + "\"\n" +
+ "}";
+ }
+
private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG =
DataFrameAnalyticsConfig.builder()
.setId("my-analytics-config")
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsRequestTests.java
new file mode 100644
index 00000000000..3f26a1ce3b5
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsRequestTests.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Optional;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class GetTrainedModelsRequestTests extends ESTestCase {
+
+ public void testValidate_Ok() {
+ assertEquals(Optional.empty(), new GetTrainedModelsRequest("valid-id").validate());
+ assertEquals(Optional.empty(), new GetTrainedModelsRequest("").validate());
+ }
+
+ public void testValidate_Failure() {
+ assertThat(new GetTrainedModelsRequest(new String[0]).validate().get().getMessage(),
+ containsString("trained model id must not be null"));
+ }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
index 6d489f32472..51f9692b8b6 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
@@ -39,7 +39,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase Constructing a new GET request referencing an existing Trained Model
+<2> Set the paging parameters
+<3> Indicate if the complete model definition should be included
+<4> Should the definition be fully decompressed on GET
+<5> Allow empty response if no Trained Models match the provided ID patterns.
+ If false, an error will be thrown if no Trained Models match the
+ ID patterns.
+
+include::../execution.asciidoc[]
+
+[id="{upid}-{api}-response"]
+==== Response
+
+The returned +{response}+ contains the requested Trained Model.
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-response]
+--------------------------------------------------
diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc
index 153a0cf577c..770866a0755 100644
--- a/docs/java-rest/high-level/supported-apis.asciidoc
+++ b/docs/java-rest/high-level/supported-apis.asciidoc
@@ -301,6 +301,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
* <<{upid}-stop-data-frame-analytics>>
* <<{upid}-evaluate-data-frame>>
* <<{upid}-estimate-memory-usage>>
+* <<{upid}-get-trained-models>>
* <<{upid}-put-filter>>
* <<{upid}-get-filters>>
* <<{upid}-update-filter>>
@@ -353,6 +354,7 @@ include::ml/start-data-frame-analytics.asciidoc[]
include::ml/stop-data-frame-analytics.asciidoc[]
include::ml/evaluate-data-frame.asciidoc[]
include::ml/estimate-memory-usage.asciidoc[]
+include::ml/get-trained-models.asciidoc[]
include::ml/put-filter.asciidoc[]
include::ml/get-filters.asciidoc[]
include::ml/update-filter.asciidoc[]