* [ML][Inference] PUT API (#50852) This adds the `PUT` API for creating trained models that support our format. This includes * HLRC change for the API * API creation * Validations of model format and call * fixing backport
This commit is contained in:
parent
456de59698
commit
fa116a6d26
|
@ -73,6 +73,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
|
|||
import org.elasticsearch.client.ml.PutDatafeedRequest;
|
||||
import org.elasticsearch.client.ml.PutFilterRequest;
|
||||
import org.elasticsearch.client.ml.PutJobRequest;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelRequest;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
|
||||
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
|
||||
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
|
||||
|
@ -792,6 +793,16 @@ final class MLRequestConverters {
|
|||
return new Request(HttpDelete.METHOD_NAME, endpoint);
|
||||
}
|
||||
|
||||
static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) throws IOException {
|
||||
String endpoint = new EndpointBuilder()
|
||||
.addPathPartAsIs("_ml", "inference")
|
||||
.addPathPart(putTrainedModelRequest.getTrainedModelConfig().getModelId())
|
||||
.build();
|
||||
Request request = new Request(HttpPut.METHOD_NAME, endpoint);
|
||||
request.setEntity(createEntity(putTrainedModelRequest, REQUEST_BODY_CONTENT_TYPE));
|
||||
return request;
|
||||
}
|
||||
|
||||
static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
|
||||
String endpoint = new EndpointBuilder()
|
||||
.addPathPartAsIs("_ml")
|
||||
|
|
|
@ -100,6 +100,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
|
|||
import org.elasticsearch.client.ml.PutFilterResponse;
|
||||
import org.elasticsearch.client.ml.PutJobRequest;
|
||||
import org.elasticsearch.client.ml.PutJobResponse;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelRequest;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelResponse;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
|
||||
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
|
||||
|
@ -2340,6 +2342,48 @@ public final class MachineLearningClient {
|
|||
Collections.emptySet());
|
||||
}
|
||||
|
||||
/**
|
||||
* Put trained model config
|
||||
* <p>
|
||||
* For additional info
|
||||
* see <a href="TODO">
|
||||
* PUT Trained Model Config documentation</a>
|
||||
*
|
||||
* @param request The {@link PutTrainedModelRequest}
|
||||
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
|
||||
* @return {@link PutTrainedModelResponse} response object
|
||||
*/
|
||||
public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, RequestOptions options) throws IOException {
|
||||
return restHighLevelClient.performRequestAndParseEntity(request,
|
||||
MLRequestConverters::putTrainedModel,
|
||||
options,
|
||||
PutTrainedModelResponse::fromXContent,
|
||||
Collections.emptySet());
|
||||
}
|
||||
|
||||
/**
|
||||
* Put trained model config asynchronously and notifies listener upon completion
|
||||
* <p>
|
||||
* For additional info
|
||||
* see <a href="TODO">
|
||||
* PUT Trained Model Config documentation</a>
|
||||
*
|
||||
* @param request The {@link PutTrainedModelRequest}
|
||||
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
|
||||
* @param listener Listener to be notified upon request completion
|
||||
* @return cancellable that may be used to cancel the request
|
||||
*/
|
||||
public Cancellable putTrainedModelAsync(PutTrainedModelRequest request,
|
||||
RequestOptions options,
|
||||
ActionListener<PutTrainedModelResponse> listener) {
|
||||
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
|
||||
MLRequestConverters::putTrainedModel,
|
||||
options,
|
||||
PutTrainedModelResponse::fromXContent,
|
||||
listener,
|
||||
Collections.emptySet());
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets trained model stats
|
||||
* <p>
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml;
|
||||
|
||||
import org.elasticsearch.client.Validatable;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
public class PutTrainedModelRequest implements Validatable, ToXContentObject {
|
||||
|
||||
private final TrainedModelConfig config;
|
||||
|
||||
public PutTrainedModelRequest(TrainedModelConfig config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
public TrainedModelConfig getTrainedModelConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
return config.toXContent(builder, params);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PutTrainedModelRequest request = (PutTrainedModelRequest) o;
|
||||
return Objects.equals(config, request.config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final String toString() {
|
||||
return Strings.toString(config);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
public class PutTrainedModelResponse implements ToXContentObject {
|
||||
|
||||
private final TrainedModelConfig trainedModelConfig;
|
||||
|
||||
public static PutTrainedModelResponse fromXContent(XContentParser parser) throws IOException {
|
||||
return new PutTrainedModelResponse(TrainedModelConfig.PARSER.parse(parser, null).build());
|
||||
}
|
||||
|
||||
public PutTrainedModelResponse(TrainedModelConfig trainedModelConfig) {
|
||||
this.trainedModelConfig = trainedModelConfig;
|
||||
}
|
||||
|
||||
public TrainedModelConfig getResponse() {
|
||||
return trainedModelConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return trainedModelConfig.toXContent(builder, params);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PutTrainedModelResponse response = (PutTrainedModelResponse) o;
|
||||
return Objects.equals(trainedModelConfig, response.trainedModelConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(trainedModelConfig);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.CheckedFunction;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.Streams;
|
||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.xcontent.DeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Base64;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
import java.util.zip.GZIPOutputStream;
|
||||
|
||||
/**
|
||||
* Collection of helper methods. Similar to CompressedXContent, but this utilizes GZIP.
|
||||
*/
|
||||
public final class InferenceToXContentCompressor {
|
||||
private static final int BUFFER_SIZE = 4096;
|
||||
private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
|
||||
|
||||
private InferenceToXContentCompressor() {}
|
||||
|
||||
public static <T extends ToXContentObject> String deflate(T objectToCompress) throws IOException {
|
||||
BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
|
||||
return deflate(reference);
|
||||
}
|
||||
|
||||
public static <T> T inflate(String compressedString,
|
||||
CheckedFunction<XContentParser, T, IOException> parserFunction,
|
||||
NamedXContentRegistry xContentRegistry) throws IOException {
|
||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
inflate(compressedString, MAX_INFLATED_BYTES),
|
||||
XContentType.JSON)) {
|
||||
return parserFunction.apply(parser);
|
||||
}
|
||||
}
|
||||
|
||||
static BytesReference inflate(String compressedString, long streamSize) throws IOException {
|
||||
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
|
||||
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
|
||||
InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
|
||||
return Streams.readFully(inflateStream);
|
||||
}
|
||||
|
||||
private static String deflate(BytesReference reference) throws IOException {
|
||||
BytesStreamOutput out = new BytesStreamOutput();
|
||||
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
|
||||
reference.writeTo(compressedOutput);
|
||||
}
|
||||
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* This is a pared down bounded input stream.
|
||||
* Only read is specifically enforced.
|
||||
*/
|
||||
final class SimpleBoundedInputStream extends InputStream {
|
||||
|
||||
private final InputStream in;
|
||||
private final long maxBytes;
|
||||
private long numBytes;
|
||||
|
||||
SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
|
||||
this.in = Objects.requireNonNull(inputStream, "inputStream");
|
||||
if (maxBytes < 0) {
|
||||
throw new IllegalArgumentException("[maxBytes] must be greater than or equal to 0");
|
||||
}
|
||||
this.maxBytes = maxBytes;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
|
||||
* @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
|
||||
* @throws IOException on failure
|
||||
*/
|
||||
@Override
|
||||
public int read() throws IOException {
|
||||
// We have reached the maximum, signal stream completion.
|
||||
if (numBytes >= maxBytes) {
|
||||
return -1;
|
||||
}
|
||||
numBytes++;
|
||||
return in.read();
|
||||
}
|
||||
|
||||
/**
|
||||
* Delegates `close` to the wrapped InputStream
|
||||
* @throws IOException on failure
|
||||
*/
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
in.close();
|
||||
}
|
||||
}
|
|
@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -111,7 +112,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
this.modelId = modelId;
|
||||
this.createdBy = createdBy;
|
||||
this.version = version;
|
||||
this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli());
|
||||
this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
|
||||
this.definition = definition;
|
||||
this.compressedDefinition = compressedDefinition;
|
||||
this.description = description;
|
||||
|
@ -293,12 +294,12 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setCreatedBy(String createdBy) {
|
||||
private Builder setCreatedBy(String createdBy) {
|
||||
this.createdBy = createdBy;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setVersion(Version version) {
|
||||
private Builder setVersion(Version version) {
|
||||
this.version = version;
|
||||
return this;
|
||||
}
|
||||
|
@ -312,7 +313,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setCreateTime(Instant createTime) {
|
||||
private Builder setCreateTime(Instant createTime) {
|
||||
this.createTime = createTime;
|
||||
return this;
|
||||
}
|
||||
|
@ -322,6 +323,10 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setTags(String... tags) {
|
||||
return setTags(Arrays.asList(tags));
|
||||
}
|
||||
|
||||
public Builder setMetadata(Map<String, Object> metadata) {
|
||||
this.metadata = metadata;
|
||||
return this;
|
||||
|
@ -347,17 +352,17 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
|
||||
private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
|
||||
this.estimatedHeapMemory = estimatedHeapMemory;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEstimatedOperations(Long estimatedOperations) {
|
||||
private Builder setEstimatedOperations(Long estimatedOperations) {
|
||||
this.estimatedOperations = estimatedOperations;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setLicenseLevel(String licenseLevel) {
|
||||
private Builder setLicenseLevel(String licenseLevel) {
|
||||
this.licenseLevel = licenseLevel;
|
||||
return this;
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -48,6 +49,10 @@ public class TrainedModelInput implements ToXContentObject {
|
|||
this.fieldNames = fieldNames;
|
||||
}
|
||||
|
||||
public TrainedModelInput(String... fieldNames) {
|
||||
this(Arrays.asList(fieldNames));
|
||||
}
|
||||
|
||||
public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
|
||||
return PARSER.parse(parser, null);
|
||||
}
|
||||
|
|
|
@ -71,6 +71,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
|
|||
import org.elasticsearch.client.ml.PutDatafeedRequest;
|
||||
import org.elasticsearch.client.ml.PutFilterRequest;
|
||||
import org.elasticsearch.client.ml.PutJobRequest;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelRequest;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
|
||||
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
|
||||
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
|
||||
|
@ -91,6 +92,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
|
|||
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
||||
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
|
||||
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.client.ml.job.config.Detector;
|
||||
import org.elasticsearch.client.ml.job.config.Job;
|
||||
|
@ -874,6 +878,20 @@ public class MLRequestConvertersTests extends ESTestCase {
|
|||
assertNull(request.getEntity());
|
||||
}
|
||||
|
||||
public void testPutTrainedModel() throws IOException {
|
||||
TrainedModelConfig trainedModelConfig = TrainedModelConfigTests.createTestTrainedModelConfig();
|
||||
PutTrainedModelRequest putTrainedModelRequest = new PutTrainedModelRequest(trainedModelConfig);
|
||||
|
||||
Request request = MLRequestConverters.putTrainedModel(putTrainedModelRequest);
|
||||
|
||||
assertEquals(HttpPut.METHOD_NAME, request.getMethod());
|
||||
assertThat(request.getEndpoint(), equalTo("/_ml/inference/" + trainedModelConfig.getModelId()));
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) {
|
||||
TrainedModelConfig parsedTrainedModelConfig = TrainedModelConfig.PARSER.apply(parser, null).build();
|
||||
assertThat(parsedTrainedModelConfig, equalTo(trainedModelConfig));
|
||||
}
|
||||
}
|
||||
|
||||
public void testPutFilter() throws IOException {
|
||||
MlFilter filter = MlFilterTests.createRandomBuilder("foo").build();
|
||||
PutFilterRequest putFilterRequest = new PutFilterRequest(filter);
|
||||
|
@ -1046,6 +1064,7 @@ public class MLRequestConvertersTests extends ESTestCase {
|
|||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
|
|
|
@ -101,6 +101,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
|
|||
import org.elasticsearch.client.ml.PutFilterResponse;
|
||||
import org.elasticsearch.client.ml.PutJobRequest;
|
||||
import org.elasticsearch.client.ml.PutJobResponse;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelRequest;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelResponse;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
|
||||
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
|
||||
|
@ -146,9 +148,12 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
|
|||
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
|
||||
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
|
||||
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
||||
import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
|
||||
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
|
||||
|
@ -162,14 +167,12 @@ import org.elasticsearch.client.ml.job.config.MlFilter;
|
|||
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
|
||||
import org.elasticsearch.client.ml.job.stats.JobStats;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
|
@ -178,11 +181,9 @@ import org.elasticsearch.search.SearchHit;
|
|||
import org.junit.After;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -190,7 +191,6 @@ import java.util.Locale;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.zip.GZIPOutputStream;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
|
@ -2222,6 +2222,50 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testPutTrainedModel() throws Exception {
|
||||
String modelId = "test-put-trained-model";
|
||||
String modelIdCompressed = "test-put-trained-model-compressed-definition";
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
||||
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
|
||||
.setDefinition(definition)
|
||||
.setModelId(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
|
||||
.setDescription("test model")
|
||||
.build();
|
||||
PutTrainedModelResponse putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
|
||||
machineLearningClient::putTrainedModel,
|
||||
machineLearningClient::putTrainedModelAsync);
|
||||
TrainedModelConfig createdModel = putTrainedModelResponse.getResponse();
|
||||
assertThat(createdModel.getModelId(), equalTo(modelId));
|
||||
|
||||
definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
||||
trainedModelConfig = TrainedModelConfig.builder()
|
||||
.setCompressedDefinition(InferenceToXContentCompressor.deflate(definition))
|
||||
.setModelId(modelIdCompressed)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
|
||||
.setDescription("test model")
|
||||
.build();
|
||||
putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
|
||||
machineLearningClient::putTrainedModel,
|
||||
machineLearningClient::putTrainedModelAsync);
|
||||
createdModel = putTrainedModelResponse.getResponse();
|
||||
assertThat(createdModel.getModelId(), equalTo(modelIdCompressed));
|
||||
|
||||
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
||||
new GetTrainedModelsRequest(modelIdCompressed).setDecompressDefinition(true).setIncludeDefinition(true),
|
||||
machineLearningClient::getTrainedModels,
|
||||
machineLearningClient::getTrainedModelsAsync);
|
||||
|
||||
assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
|
||||
assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
|
||||
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue()));
|
||||
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue())));
|
||||
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdCompressed));
|
||||
}
|
||||
|
||||
public void testGetTrainedModelsStats() throws Exception {
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
String modelIdPrefix = "a-get-trained-model-stats-";
|
||||
|
@ -2504,56 +2548,13 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
private void putTrainedModel(String modelId) throws IOException {
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
||||
highLevelClient().index(
|
||||
new IndexRequest(".ml-inference-000001")
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.source(modelConfigString(modelId), XContentType.JSON)
|
||||
.id(modelId),
|
||||
RequestOptions.DEFAULT);
|
||||
|
||||
highLevelClient().index(
|
||||
new IndexRequest(".ml-inference-000001")
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
|
||||
.id("trained_model_definition_doc-" + modelId + "-0"),
|
||||
RequestOptions.DEFAULT);
|
||||
}
|
||||
|
||||
private String compressDefinition(TrainedModelDefinition definition) throws IOException {
|
||||
BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
|
||||
BytesStreamOutput out = new BytesStreamOutput();
|
||||
try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
|
||||
reference.writeTo(compressedOutput);
|
||||
}
|
||||
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
private static String modelConfigString(String modelId) {
|
||||
return "{\n" +
|
||||
" \"doc_type\": \"trained_model_config\",\n" +
|
||||
" \"model_id\": \"" + modelId + "\",\n" +
|
||||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"license_level\": \"platinum\",\n" +
|
||||
" \"created_by\": \"ml_test\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0," +
|
||||
" \"estimated_operations\": 0," +
|
||||
" \"created_time\": 0\n" +
|
||||
"}";
|
||||
}
|
||||
|
||||
private static String modelDocString(String compressedDefinition, String modelId) {
|
||||
return "" +
|
||||
"{" +
|
||||
"\"model_id\": \"" + modelId + "\",\n" +
|
||||
"\"doc_num\": 0,\n" +
|
||||
"\"doc_type\": \"trained_model_definition_doc\",\n" +
|
||||
" \"compression_version\": " + 1 + ",\n" +
|
||||
" \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
|
||||
" \"definition_length\": " + compressedDefinition.length() + ",\n" +
|
||||
"\"definition\": \"" + compressedDefinition + "\"\n" +
|
||||
"}";
|
||||
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
|
||||
.setDefinition(definition)
|
||||
.setModelId(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
|
||||
.setDescription("test model")
|
||||
.build();
|
||||
highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
|
||||
}
|
||||
|
||||
private void waitForJobToClose(String jobId) throws Exception {
|
||||
|
@ -2798,4 +2799,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
mlInfoResponse = machineLearningClient.getMlInfo(new MlInfoRequest(), RequestOptions.DEFAULT);
|
||||
assertThat(mlInfoResponse.getInfo().get("upgrade_mode"), equalTo(false));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,6 +114,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
|
|||
import org.elasticsearch.client.ml.PutFilterResponse;
|
||||
import org.elasticsearch.client.ml.PutJobRequest;
|
||||
import org.elasticsearch.client.ml.PutJobResponse;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelRequest;
|
||||
import org.elasticsearch.client.ml.PutTrainedModelResponse;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
|
||||
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
|
||||
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
|
||||
|
@ -162,10 +164,14 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
|
|||
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
|
||||
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
|
||||
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
||||
import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
|
||||
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
|
||||
import org.elasticsearch.client.ml.job.config.DataDescription;
|
||||
|
@ -186,12 +192,11 @@ import org.elasticsearch.client.ml.job.results.Influencer;
|
|||
import org.elasticsearch.client.ml.job.results.OverallBucket;
|
||||
import org.elasticsearch.client.ml.job.stats.JobStats;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
|
@ -202,12 +207,10 @@ import org.elasticsearch.tasks.TaskId;
|
|||
import org.junit.After;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
|
@ -216,7 +219,6 @@ import java.util.Map;
|
|||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.zip.GZIPOutputStream;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
|
@ -3625,6 +3627,79 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testPutTrainedModel() throws Exception {
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
||||
// tag::put-trained-model-config
|
||||
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
|
||||
.setDefinition(definition) // <1>
|
||||
.setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2>
|
||||
.setModelId("my-new-trained-model") // <3>
|
||||
.setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4>
|
||||
.setDescription("test model") // <5>
|
||||
.setMetadata(new HashMap<>()) // <6>
|
||||
.setTags("my_regression_models") // <7>
|
||||
.build();
|
||||
// end::put-trained-model-config
|
||||
|
||||
trainedModelConfig = TrainedModelConfig.builder()
|
||||
.setDefinition(definition)
|
||||
.setModelId("my-new-trained-model")
|
||||
.setInput(new TrainedModelInput("col1", "col2", "col3", "col4"))
|
||||
.setDescription("test model")
|
||||
.setMetadata(new HashMap<>())
|
||||
.setTags("my_regression_models")
|
||||
.build();
|
||||
|
||||
RestHighLevelClient client = highLevelClient();
|
||||
{
|
||||
// tag::put-trained-model-request
|
||||
PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); // <1>
|
||||
// end::put-trained-model-request
|
||||
|
||||
// tag::put-trained-model-execute
|
||||
PutTrainedModelResponse response = client.machineLearning().putTrainedModel(request, RequestOptions.DEFAULT);
|
||||
// end::put-trained-model-execute
|
||||
|
||||
// tag::put-trained-model-response
|
||||
TrainedModelConfig model = response.getResponse();
|
||||
// end::put-trained-model-response
|
||||
|
||||
assertThat(model.getModelId(), equalTo(trainedModelConfig.getModelId()));
|
||||
highLevelClient().machineLearning()
|
||||
.deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
|
||||
}
|
||||
{
|
||||
PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
|
||||
|
||||
// tag::put-trained-model-execute-listener
|
||||
ActionListener<PutTrainedModelResponse> listener = new ActionListener<PutTrainedModelResponse>() {
|
||||
@Override
|
||||
public void onResponse(PutTrainedModelResponse response) {
|
||||
// <1>
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
// <2>
|
||||
}
|
||||
};
|
||||
// end::put-trained-model-execute-listener
|
||||
|
||||
// Replace the empty listener by a blocking listener in test
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
listener = new LatchedActionListener<>(listener, latch);
|
||||
|
||||
// tag::put-trained-model-execute-async
|
||||
client.machineLearning().putTrainedModelAsync(request, RequestOptions.DEFAULT, listener); // <1>
|
||||
// end::put-trained-model-execute-async
|
||||
|
||||
assertTrue(latch.await(30L, TimeUnit.SECONDS));
|
||||
|
||||
highLevelClient().machineLearning()
|
||||
.deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
|
||||
}
|
||||
}
|
||||
|
||||
public void testGetTrainedModelsStats() throws Exception {
|
||||
putTrainedModel("my-trained-model");
|
||||
RestHighLevelClient client = highLevelClient();
|
||||
|
@ -4088,57 +4163,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
}
|
||||
|
||||
private void putTrainedModel(String modelId) throws IOException {
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
||||
highLevelClient().index(
|
||||
new IndexRequest(".ml-inference-000001")
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.source(modelConfigString(modelId), XContentType.JSON)
|
||||
.id(modelId),
|
||||
RequestOptions.DEFAULT);
|
||||
|
||||
highLevelClient().index(
|
||||
new IndexRequest(".ml-inference-000001")
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
|
||||
.id("trained_model_definition_doc-" + modelId + "-0"),
|
||||
RequestOptions.DEFAULT);
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
||||
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
|
||||
.setDefinition(definition)
|
||||
.setModelId(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
|
||||
.setDescription("test model")
|
||||
.build();
|
||||
highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
|
||||
}
|
||||
|
||||
private String compressDefinition(TrainedModelDefinition definition) throws IOException {
|
||||
BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
|
||||
BytesStreamOutput out = new BytesStreamOutput();
|
||||
try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
|
||||
reference.writeTo(compressedOutput);
|
||||
}
|
||||
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
private static String modelConfigString(String modelId) {
|
||||
return "{\n" +
|
||||
" \"doc_type\": \"trained_model_config\",\n" +
|
||||
" \"model_id\": \"" + modelId + "\",\n" +
|
||||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model for\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"license_level\": \"platinum\",\n" +
|
||||
" \"created_by\": \"ml_test\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0," +
|
||||
" \"estimated_operations\": 0," +
|
||||
" \"created_time\": 0\n" +
|
||||
"}";
|
||||
}
|
||||
|
||||
private static String modelDocString(String compressedDefinition, String modelId) {
|
||||
return "" +
|
||||
"{" +
|
||||
"\"model_id\": \"" + modelId + "\",\n" +
|
||||
"\"doc_num\": 0,\n" +
|
||||
"\"doc_type\": \"trained_model_definition_doc\",\n" +
|
||||
" \"compression_version\": " + 1 + ",\n" +
|
||||
" \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
|
||||
" \"definition_length\": " + compressedDefinition.length() + ",\n" +
|
||||
"\"definition\": \"" + compressedDefinition + "\"\n" +
|
||||
"}";
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG =
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class PutTrainedModelActionRequestTests extends AbstractXContentTestCase<PutTrainedModelRequest> {
|
||||
|
||||
@Override
|
||||
protected PutTrainedModelRequest createTestInstance() {
|
||||
return new PutTrainedModelRequest(TrainedModelConfigTests.createTestTrainedModelConfig());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PutTrainedModelRequest doParseInstance(XContentParser parser) throws IOException {
|
||||
return new PutTrainedModelRequest(TrainedModelConfig.PARSER.apply(parser, null).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class PutTrainedModelActionResponseTests extends AbstractXContentTestCase<PutTrainedModelResponse> {
|
||||
|
||||
@Override
|
||||
protected PutTrainedModelResponse createTestInstance() {
|
||||
return new PutTrainedModelResponse(TrainedModelConfigTests.createTestTrainedModelConfig());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PutTrainedModelResponse doParseInstance(XContentParser parser) throws IOException {
|
||||
return new PutTrainedModelResponse(TrainedModelConfig.PARSER.apply(parser, null).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class InferenceToXContentCompressorTests extends ESTestCase {
|
||||
|
||||
public void testInflateAndDeflate() throws IOException {
|
||||
for(int i = 0; i < 10; i++) {
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
||||
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
|
||||
TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate,
|
||||
parser -> TrainedModelDefinition.fromXContent(parser).build(),
|
||||
xContentRegistry());
|
||||
|
||||
// Did we inflate to the same object?
|
||||
assertThat(inflatedDefinition, equalTo(definition));
|
||||
}
|
||||
}
|
||||
|
||||
public void testInflateTooLargeStream() throws IOException {
|
||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
||||
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
|
||||
BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
|
||||
assertThat(inflatedBytes.length(), equalTo(10));
|
||||
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
LoggingDeprecationHandler.INSTANCE,
|
||||
inflatedBytes,
|
||||
XContentType.JSON)) {
|
||||
expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser));
|
||||
}
|
||||
}
|
||||
|
||||
public void testInflateGarbage() {
|
||||
expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
}
|
|
@ -37,6 +37,24 @@ import java.util.stream.Stream;
|
|||
|
||||
public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedModelConfig> {
|
||||
|
||||
public static TrainedModelConfig createTestTrainedModelConfig() {
|
||||
return new TrainedModelConfig(
|
||||
randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
Version.CURRENT,
|
||||
randomBoolean() ? null : randomAlphaOfLength(100),
|
||||
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
||||
randomBoolean() ? null : randomAlphaOfLength(100),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
|
||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
|
||||
randomBoolean() ? null : randomNonNegativeLong(),
|
||||
randomBoolean() ? null : randomNonNegativeLong(),
|
||||
randomBoolean() ? null : randomFrom("platinum", "basic"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException {
|
||||
return TrainedModelConfig.fromXContent(parser);
|
||||
|
@ -54,22 +72,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
|
|||
|
||||
@Override
|
||||
protected TrainedModelConfig createTestInstance() {
|
||||
return new TrainedModelConfig(
|
||||
randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
Version.CURRENT,
|
||||
randomBoolean() ? null : randomAlphaOfLength(100),
|
||||
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
||||
randomBoolean() ? null : randomAlphaOfLength(100),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
|
||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
|
||||
randomBoolean() ? null : randomNonNegativeLong(),
|
||||
randomBoolean() ? null : randomNonNegativeLong(),
|
||||
randomBoolean() ? null : randomFrom("platinum", "basic"));
|
||||
|
||||
return createTestTrainedModelConfig();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -67,15 +67,17 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
|
|||
.collect(Collectors.toList());
|
||||
int numberOfModels = randomIntBetween(1, 10);
|
||||
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
|
||||
.limit(numberOfFeatures)
|
||||
.limit(numberOfModels)
|
||||
.collect(Collectors.toList());
|
||||
OutputAggregator outputAggregator = null;
|
||||
if (randomBoolean()) {
|
||||
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
|
||||
outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights));
|
||||
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
|
||||
List<OutputAggregator> possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
|
||||
new LogisticRegression(weights)));
|
||||
if (targetType.equals(TargetType.REGRESSION)) {
|
||||
possibleAggregators.add(new WeightedSum(weights));
|
||||
}
|
||||
OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
|
||||
List<String> categoryLabels = null;
|
||||
if (randomBoolean()) {
|
||||
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
|
||||
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
|
||||
}
|
||||
return new Ensemble(featureNames,
|
||||
|
|
|
@ -84,7 +84,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
|
|||
childNodes = nextNodes;
|
||||
}
|
||||
List<String> categoryLabels = null;
|
||||
if (randomBoolean()) {
|
||||
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
|
||||
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
|
||||
}
|
||||
return builder.setClassificationLabels(categoryLabels)
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
--
|
||||
:api: put-trained-model
|
||||
:request: PutTrainedModelRequest
|
||||
:response: PutTrainedModelResponse
|
||||
--
|
||||
[role="xpack"]
|
||||
[id="{upid}-{api}"]
|
||||
=== Put Trained Model API
|
||||
|
||||
Creates a new trained model for inference.
|
||||
The API accepts a +{request}+ object as a request and returns a +{response}+.
|
||||
|
||||
[id="{upid}-{api}-request"]
|
||||
==== Put Trained Model request
|
||||
|
||||
A +{request}+ requires the following argument:
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-request]
|
||||
--------------------------------------------------
|
||||
<1> The configuration of the {infer} Trained Model to create
|
||||
|
||||
[id="{upid}-{api}-config"]
|
||||
==== Trained Model configuration
|
||||
|
||||
The `TrainedModelConfig` object contains all the details about the trained model
|
||||
configuration and contains the following arguments:
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-config]
|
||||
--------------------------------------------------
|
||||
<1> The {infer} definition for the model
|
||||
<2> Optionally, if the {infer} definition is large, you may choose to compress it for transport.
|
||||
Do not supply both the compressed and uncompressed definitions.
|
||||
<3> The unique model id
|
||||
<4> The input field names for the model definition
|
||||
<5> Optionally, a human-readable description
|
||||
<6> Optionally, an object map contain metadata about the model
|
||||
<7> Optionally, an array of tags to organize the model
|
||||
|
||||
include::../execution.asciidoc[]
|
||||
|
||||
[id="{upid}-{api}-response"]
|
||||
==== Response
|
||||
|
||||
The returned +{response}+ contains the newly created trained model.
|
||||
|
||||
["source","java",subs="attributes,callouts,macros"]
|
||||
--------------------------------------------------
|
||||
include-tagged::{doc-tests-file}[{api}-response]
|
||||
--------------------------------------------------
|
|
@ -304,6 +304,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
|
|||
* <<{upid}-evaluate-data-frame>>
|
||||
* <<{upid}-explain-data-frame-analytics>>
|
||||
* <<{upid}-get-trained-models>>
|
||||
* <<{upid}-put-trained-model>>
|
||||
* <<{upid}-get-trained-models-stats>>
|
||||
* <<{upid}-delete-trained-model>>
|
||||
* <<{upid}-put-filter>>
|
||||
|
@ -359,6 +360,7 @@ include::ml/stop-data-frame-analytics.asciidoc[]
|
|||
include::ml/evaluate-data-frame.asciidoc[]
|
||||
include::ml/explain-data-frame-analytics.asciidoc[]
|
||||
include::ml/get-trained-models.asciidoc[]
|
||||
include::ml/put-trained-model.asciidoc[]
|
||||
include::ml/get-trained-models-stats.asciidoc[]
|
||||
include::ml/delete-trained-model.asciidoc[]
|
||||
include::ml/put-filter.asciidoc[]
|
||||
|
|
|
@ -126,6 +126,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
|
||||
|
@ -381,6 +382,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
|||
GetTrainedModelsAction.INSTANCE,
|
||||
DeleteTrainedModelAction.INSTANCE,
|
||||
GetTrainedModelsStatsAction.INSTANCE,
|
||||
PutTrainedModelAction.INSTANCE,
|
||||
// security
|
||||
ClearRealmCacheAction.INSTANCE,
|
||||
ClearRolesCacheAction.INSTANCE,
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedRequest;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {
|
||||
|
||||
public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
|
||||
public static final String NAME = "cluster:monitor/xpack/ml/inference/put";
|
||||
private PutTrainedModelAction() {
|
||||
super(NAME, Response::new);
|
||||
}
|
||||
|
||||
public static class Request extends AcknowledgedRequest<Request> {
|
||||
|
||||
public static Request parseRequest(String modelId, XContentParser parser) {
|
||||
TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);
|
||||
|
||||
if (builder.getModelId() == null) {
|
||||
builder.setModelId(modelId).build();
|
||||
} else if (!Strings.isNullOrEmpty(modelId) && !modelId.equals(builder.getModelId())) {
|
||||
// If we have model_id in both URI and body, they must be identical
|
||||
throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID,
|
||||
TrainedModelConfig.MODEL_ID.getPreferredName(),
|
||||
builder.getModelId(),
|
||||
modelId));
|
||||
}
|
||||
// Validations are done against the builder so we can build the full config object.
|
||||
// This allows us to not worry about serializing a builder class between nodes.
|
||||
return new Request(builder.validate(true).build());
|
||||
}
|
||||
|
||||
private final TrainedModelConfig config;
|
||||
|
||||
public Request(TrainedModelConfig config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.config = new TrainedModelConfig(in);
|
||||
}
|
||||
|
||||
public TrainedModelConfig getTrainedModelConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActionRequestValidationException validate() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
config.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Request request = (Request) o;
|
||||
return Objects.equals(config, request.config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final String toString() {
|
||||
return Strings.toString(config);
|
||||
}
|
||||
}
|
||||
|
||||
public static class Response extends ActionResponse implements ToXContentObject {
|
||||
|
||||
private final TrainedModelConfig trainedModelConfig;
|
||||
|
||||
public Response(TrainedModelConfig trainedModelConfig) {
|
||||
this.trainedModelConfig = trainedModelConfig;
|
||||
}
|
||||
|
||||
public Response(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
trainedModelConfig = new TrainedModelConfig(in);
|
||||
}
|
||||
|
||||
public TrainedModelConfig getResponse() {
|
||||
return trainedModelConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
trainedModelConfig.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return trainedModelConfig.toXContent(builder, params);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Response response = (Response) o;
|
||||
return Objects.equals(trainedModelConfig, response.trainedModelConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(trainedModelConfig);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference;
|
|||
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
|
@ -34,6 +35,9 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.action.ValidateActions.addValidationError;
|
||||
|
||||
|
||||
public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||
|
@ -352,13 +356,31 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
private Long estimatedHeapMemory;
|
||||
private Long estimatedOperations;
|
||||
private LazyModelDefinition definition;
|
||||
private String licenseLevel = License.OperationMode.PLATINUM.description();
|
||||
private String licenseLevel;
|
||||
|
||||
public Builder() {}
|
||||
|
||||
public Builder(TrainedModelConfig config) {
|
||||
this.modelId = config.getModelId();
|
||||
this.createdBy = config.getCreatedBy();
|
||||
this.version = config.getVersion();
|
||||
this.createTime = config.getCreateTime();
|
||||
this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
|
||||
this.description = config.getDescription();
|
||||
this.tags = config.getTags();
|
||||
this.metadata = config.getMetadata();
|
||||
this.input = config.getInput();
|
||||
}
|
||||
|
||||
public Builder setModelId(String modelId) {
|
||||
this.modelId = modelId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
return this.modelId;
|
||||
}
|
||||
|
||||
public Builder setCreatedBy(String createdBy) {
|
||||
this.createdBy = createdBy;
|
||||
return this;
|
||||
|
@ -466,51 +488,96 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
return this;
|
||||
}
|
||||
|
||||
// TODO move to REST level instead of here in the builder
|
||||
public void validate() {
|
||||
// We require a definition to be available here even though it will be stored in a different doc
|
||||
ExceptionsHelper.requireNonNull(definition, DEFINITION);
|
||||
ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||
|
||||
if (MlStrings.isValidId(modelId) == false) {
|
||||
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId));
|
||||
}
|
||||
|
||||
if (MlStrings.hasValidLengthForId(modelId) == false) {
|
||||
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG,
|
||||
MODEL_ID.getPreferredName(),
|
||||
modelId,
|
||||
MlStrings.ID_LENGTH_LIMIT));
|
||||
}
|
||||
|
||||
checkIllegalSetting(version, VERSION.getPreferredName());
|
||||
checkIllegalSetting(createdBy, CREATED_BY.getPreferredName());
|
||||
checkIllegalSetting(createTime, CREATE_TIME.getPreferredName());
|
||||
checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
|
||||
checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName());
|
||||
checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName());
|
||||
public Builder validate() {
|
||||
return validate(false);
|
||||
}
|
||||
|
||||
private static void checkIllegalSetting(Object value, String setting) {
|
||||
if (value != null) {
|
||||
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting);
|
||||
/**
|
||||
* Runs validations against the builder.
|
||||
* @return The current builder object if validations are successful
|
||||
* @throws ActionRequestValidationException when there are validation failures.
|
||||
*/
|
||||
public Builder validate(boolean forCreation) {
|
||||
// We require a definition to be available here even though it will be stored in a different doc
|
||||
ActionRequestValidationException validationException = null;
|
||||
if (definition == null) {
|
||||
validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] must not be null.", validationException);
|
||||
}
|
||||
if (modelId == null) {
|
||||
validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException);
|
||||
}
|
||||
|
||||
if (modelId != null && MlStrings.isValidId(modelId) == false) {
|
||||
validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID,
|
||||
TrainedModelConfig.MODEL_ID.getPreferredName(),
|
||||
modelId),
|
||||
validationException);
|
||||
}
|
||||
if (modelId != null && MlStrings.hasValidLengthForId(modelId) == false) {
|
||||
validationException = addValidationError(Messages.getMessage(Messages.ID_TOO_LONG,
|
||||
TrainedModelConfig.MODEL_ID.getPreferredName(),
|
||||
modelId,
|
||||
MlStrings.ID_LENGTH_LIMIT), validationException);
|
||||
}
|
||||
List<String> badTags = tags.stream()
|
||||
.filter(tag -> (MlStrings.isValidId(tag) && MlStrings.hasValidLengthForId(tag)) == false)
|
||||
.collect(Collectors.toList());
|
||||
if (badTags.isEmpty() == false) {
|
||||
validationException = addValidationError(Messages.getMessage(Messages.INFERENCE_INVALID_TAGS,
|
||||
badTags,
|
||||
MlStrings.ID_LENGTH_LIMIT),
|
||||
validationException);
|
||||
}
|
||||
|
||||
for(String tag : tags) {
|
||||
if (tag.equals(modelId)) {
|
||||
validationException = addValidationError("none of the tags must equal the model_id", validationException);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (forCreation) {
|
||||
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
|
||||
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
|
||||
validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
|
||||
validationException = checkIllegalSetting(estimatedHeapMemory,
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
|
||||
validationException);
|
||||
validationException = checkIllegalSetting(estimatedOperations,
|
||||
ESTIMATED_OPERATIONS.getPreferredName(),
|
||||
validationException);
|
||||
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
|
||||
}
|
||||
|
||||
if (validationException != null) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
private static ActionRequestValidationException checkIllegalSetting(Object value,
|
||||
String setting,
|
||||
ActionRequestValidationException validationException) {
|
||||
if (value != null) {
|
||||
return addValidationError("illegal to set [" + setting + "] at inference model creation", validationException);
|
||||
}
|
||||
return validationException;
|
||||
}
|
||||
|
||||
public TrainedModelConfig build() {
|
||||
return new TrainedModelConfig(
|
||||
modelId,
|
||||
createdBy,
|
||||
version,
|
||||
createdBy == null ? "user" : createdBy,
|
||||
version == null ? Version.CURRENT : version,
|
||||
description,
|
||||
createTime == null ? Instant.now() : createTime,
|
||||
definition,
|
||||
tags,
|
||||
metadata,
|
||||
input,
|
||||
estimatedHeapMemory,
|
||||
estimatedOperations,
|
||||
licenseLevel);
|
||||
estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
|
||||
estimatedOperations == null ? 0 : estimatedOperations,
|
||||
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -531,6 +598,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
return new LazyModelDefinition(input.readString(), null);
|
||||
}
|
||||
|
||||
private LazyModelDefinition(LazyModelDefinition definition) {
|
||||
if (definition != null) {
|
||||
this.compressedString = definition.compressedString;
|
||||
this.parsedDefinition = definition.parsedDefinition;
|
||||
}
|
||||
}
|
||||
|
||||
private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) {
|
||||
if (compressedString == null && trainedModelDefinition == null) {
|
||||
throw new IllegalArgumentException("unexpected null model definition");
|
||||
|
|
|
@ -179,6 +179,12 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
|
|||
this(true);
|
||||
}
|
||||
|
||||
public Builder(TrainedModelDefinition definition) {
|
||||
this(true);
|
||||
this.preProcessors = new ArrayList<>(definition.getPreProcessors());
|
||||
this.trainedModel = definition.trainedModel;
|
||||
}
|
||||
|
||||
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
|
||||
this.preProcessors = preProcessors;
|
||||
return this;
|
||||
|
|
|
@ -95,6 +95,10 @@ public final class Messages {
|
|||
public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
|
||||
"Getting model definition is not supported when getting more than one model";
|
||||
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
|
||||
public static final String INFERENCE_INVALID_TAGS = "Invalid tags {0}; must only can contain lowercase alphanumeric (a-z and 0-9), " +
|
||||
"hyphens or underscores, must start and end with alphanumeric, and must be less than {1} characters.";
|
||||
public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids.";
|
||||
public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags.";
|
||||
|
||||
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
|
||||
public static final String JOB_AUDIT_CREATED = "Job created";
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||
|
||||
public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
|
||||
|
||||
@Override
|
||||
protected Request createTestInstance() {
|
||||
String modelId = randomAlphaOfLength(10);
|
||||
return new Request(TrainedModelConfigTests.createTestInstance(modelId)
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.build());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Request> instanceReader() {
|
||||
return (in) -> {
|
||||
Request request = new Request(in);
|
||||
request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry());
|
||||
return request;
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||
|
||||
public class PutTrainedModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
||||
|
||||
@Override
|
||||
protected Response createTestInstance() {
|
||||
String modelId = randomAlphaOfLength(10);
|
||||
return new Response(TrainedModelConfigTests.createTestInstance(modelId)
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.build());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Response> instanceReader() {
|
||||
return (in) -> {
|
||||
Response response = new Response(in);
|
||||
response.getResponse().ensureParsedDefinition(xContentRegistry());
|
||||
return response;
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
}
|
||||
}
|
|
@ -5,8 +5,8 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
|
@ -56,14 +56,16 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
return TrainedModelConfig.builder()
|
||||
.setInput(TrainedModelInputTests.createRandomInput())
|
||||
.setMetadata(randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)))
|
||||
.setCreateTime(Instant.ofEpochMilli(randomNonNegativeLong()))
|
||||
.setCreateTime(Instant.ofEpochMilli(randomLongBetween(Instant.MIN.getEpochSecond(), Instant.MAX.getEpochSecond())))
|
||||
.setVersion(Version.CURRENT)
|
||||
.setModelId(modelId)
|
||||
.setCreatedBy(randomAlphaOfLength(10))
|
||||
.setDescription(randomBoolean() ? null : randomAlphaOfLength(100))
|
||||
.setEstimatedHeapMemory(randomNonNegativeLong())
|
||||
.setEstimatedOperations(randomNonNegativeLong())
|
||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||
.setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(),
|
||||
License.OperationMode.GOLD.description(),
|
||||
License.OperationMode.BASIC.description()))
|
||||
.setTags(tags);
|
||||
}
|
||||
|
||||
|
@ -191,50 +193,52 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
}
|
||||
|
||||
public void testValidateWithNullDefinition() {
|
||||
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
|
||||
assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
|
||||
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
|
||||
() -> TrainedModelConfig.builder().validate());
|
||||
assertThat(ex.getMessage(), containsString("[definition] must not be null."));
|
||||
}
|
||||
|
||||
public void testValidateWithInvalidID() {
|
||||
String modelId = "InvalidID-";
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
|
||||
assertThat(ex.getMessage(), containsString(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
|
||||
}
|
||||
|
||||
public void testValidateWithLongID() {
|
||||
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
|
||||
assertThat(ex.getMessage(),
|
||||
containsString(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
|
||||
}
|
||||
|
||||
public void testValidateWithIllegallyUserProvidedFields() {
|
||||
String modelId = "simplemodel";
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setCreateTime(Instant.now())
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
|
||||
.setModelId(modelId).validate(true));
|
||||
assertThat(ex.getMessage(), containsString("illegal to set [create_time] at inference model creation"));
|
||||
|
||||
ex = expectThrows(ElasticsearchException.class,
|
||||
ex = expectThrows(ActionRequestValidationException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setVersion(Version.CURRENT)
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
|
||||
.setModelId(modelId).validate(true));
|
||||
assertThat(ex.getMessage(), containsString("illegal to set [version] at inference model creation"));
|
||||
|
||||
ex = expectThrows(ElasticsearchException.class,
|
||||
ex = expectThrows(ActionRequestValidationException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setCreatedBy("ml_user")
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
|
||||
.setModelId(modelId).validate(true));
|
||||
assertThat(ex.getMessage(), containsString("illegal to set [created_by] at inference model creation"));
|
||||
}
|
||||
|
||||
public void testSerializationWithLazyDefinition() throws IOException {
|
||||
|
|
|
@ -133,7 +133,6 @@ integTest.runner {
|
|||
'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id',
|
||||
'ml/get_datafeeds/Test get datafeed given missing datafeed_id',
|
||||
'ml/inference_crud/Test delete given used trained model',
|
||||
'ml/inference_crud/Test delete given unused trained model',
|
||||
'ml/inference_crud/Test delete with missing model',
|
||||
'ml/inference_crud/Test get given missing trained model',
|
||||
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
|
||||
|
|
|
@ -9,16 +9,20 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
|||
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.xcontent.DeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.mapper.MapperService;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -35,26 +39,14 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
|
|||
|
||||
@Before
|
||||
public void createBothModels() throws Exception {
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId("test_classification")
|
||||
.setSource(CLASSIFICATION_CONFIG, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(TrainedModelDefinitionDoc.docId("test_classification", 0))
|
||||
.setSource(buildClassificationModelDoc(), XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId("test_regression")
|
||||
.setSource(REGRESSION_CONFIG, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(TrainedModelDefinitionDoc.docId("test_regression", 0))
|
||||
.setSource(buildRegressionModelDoc(), XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
|
||||
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
|
||||
}
|
||||
|
||||
@After
|
||||
public void deleteBothModels() {
|
||||
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
|
||||
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
|
||||
}
|
||||
|
||||
public void testPipelineCreationAndDeletion() throws Exception {
|
||||
|
@ -392,6 +384,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
|
|||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model for regression\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"definition\": " + REGRESSION_DEFINITION + ","+
|
||||
" \"license_level\": \"platinum\",\n" +
|
||||
" \"created_by\": \"ml_test\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0," +
|
||||
|
@ -519,28 +512,27 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
|
|||
" }\n" +
|
||||
"}";
|
||||
|
||||
private static String buildClassificationModelDoc() throws IOException {
|
||||
String compressed =
|
||||
InferenceToXContentCompressor.deflate(new BytesArray(CLASSIFICATION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
|
||||
return modelDocString(compressed, "test_classification");
|
||||
private TrainedModelConfig buildClassificationModel() throws IOException {
|
||||
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
new BytesArray(CLASSIFICATION_CONFIG),
|
||||
XContentType.JSON)) {
|
||||
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
|
||||
}
|
||||
}
|
||||
|
||||
private static String buildRegressionModelDoc() throws IOException {
|
||||
String compressed = InferenceToXContentCompressor.deflate(new BytesArray(REGRESSION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
|
||||
return modelDocString(compressed, "test_regression");
|
||||
private TrainedModelConfig buildRegressionModel() throws IOException {
|
||||
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
new BytesArray(REGRESSION_CONFIG),
|
||||
XContentType.JSON)) {
|
||||
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
|
||||
}
|
||||
}
|
||||
|
||||
private static String modelDocString(String compressedDefinition, String modelId) {
|
||||
return "" +
|
||||
"{" +
|
||||
"\"model_id\": \"" + modelId + "\",\n" +
|
||||
"\"doc_num\": 0,\n" +
|
||||
"\"doc_type\": \"trained_model_definition_doc\",\n" +
|
||||
" \"compression_version\": " + 1 + ",\n" +
|
||||
" \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
|
||||
" \"definition_length\": " + compressedDefinition.length() + ",\n" +
|
||||
"\"definition\": \"" + compressedDefinition + "\"\n" +
|
||||
"}";
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
private static final String CLASSIFICATION_CONFIG = "" +
|
||||
|
@ -549,8 +541,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
|
|||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model for classification\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"definition\": " + CLASSIFICATION_DEFINITION + ","+
|
||||
" \"license_level\": \"platinum\",\n" +
|
||||
" \"created_by\": \"benwtrent\",\n" +
|
||||
" \"created_by\": \"es_test\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0," +
|
||||
" \"estimated_operations\": 0," +
|
||||
" \"created_time\": 0\n" +
|
||||
|
|
|
@ -6,10 +6,18 @@
|
|||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.apache.http.util.EntityUtils;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.client.Request;
|
||||
import org.elasticsearch.client.Response;
|
||||
import org.elasticsearch.client.ResponseException;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.client.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
|
@ -18,26 +26,19 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
|
||||
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
||||
import org.junit.After;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
|
@ -62,22 +63,8 @@ public class TrainedModelIT extends ESRestTestCase {
|
|||
public void testGetTrainedModels() throws IOException {
|
||||
String modelId = "a_test_regression_model";
|
||||
String modelId2 = "a_test_regression_model-2";
|
||||
Request model1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
|
||||
model1.setJsonEntity(buildRegressionModel(modelId));
|
||||
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
Request modelDefinition1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
|
||||
modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
|
||||
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
Request model2 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
|
||||
model2.setJsonEntity(buildRegressionModel(modelId2));
|
||||
assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
|
||||
putRegressionModel(modelId);
|
||||
putRegressionModel(modelId2);
|
||||
Response getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/" + modelId));
|
||||
|
||||
|
@ -164,17 +151,7 @@ public class TrainedModelIT extends ESRestTestCase {
|
|||
|
||||
public void testDeleteTrainedModels() throws IOException {
|
||||
String modelId = "test_delete_regression_model";
|
||||
Request model1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
|
||||
model1.setJsonEntity(buildRegressionModel(modelId));
|
||||
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
Request modelDefinition1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
|
||||
modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
|
||||
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
|
||||
putRegressionModel(modelId);
|
||||
|
||||
Response delModel = client().performRequest(new Request("DELETE",
|
||||
MachineLearning.BASE_PATH + "inference/" + modelId));
|
||||
|
@ -208,42 +185,68 @@ public class TrainedModelIT extends ESRestTestCase {
|
|||
assertThat(response, containsString("\"definition\""));
|
||||
}
|
||||
|
||||
private static String buildRegressionModel(String modelId) throws IOException {
|
||||
private void putRegressionModel(String modelId) throws IOException {
|
||||
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Collections.emptyList())
|
||||
.setTrainedModel(buildRegression());
|
||||
TrainedModelConfig.builder()
|
||||
.setDefinition(definition)
|
||||
.setModelId(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
|
||||
.setCreatedBy("ml_test")
|
||||
.setVersion(Version.CURRENT)
|
||||
.setCreateTime(Instant.now())
|
||||
.setEstimatedOperations(0)
|
||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||
.setEstimatedHeapMemory(0)
|
||||
.build()
|
||||
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
|
||||
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
|
||||
.build().toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
Request model = new Request("PUT", "_ml/inference/" + modelId);
|
||||
model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
|
||||
assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
|
||||
}
|
||||
}
|
||||
|
||||
private static String buildRegressionModelDefinitionDoc(String modelId) throws IOException {
|
||||
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Collections.emptyList())
|
||||
.setTrainedModel(LocalModelTests.buildRegression())
|
||||
.build();
|
||||
String compressedString = InferenceToXContentCompressor.deflate(definition);
|
||||
TrainedModelDefinitionDoc doc = new TrainedModelDefinitionDoc.Builder().setDocNum(0)
|
||||
.setCompressedString(compressedString)
|
||||
.setTotalDefinitionLength(compressedString.length())
|
||||
.setDefinitionLength(compressedString.length())
|
||||
.setCompressionVersion(1)
|
||||
.setModelId(modelId).build();
|
||||
doc.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
|
||||
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
|
||||
}
|
||||
private static TrainedModel buildRegression() {
|
||||
List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
|
||||
Tree tree1 = Tree.builder()
|
||||
.setFeatureNames(featureNames)
|
||||
.setNodes(TreeNode.builder(0)
|
||||
.setLeftChild(1)
|
||||
.setRightChild(2)
|
||||
.setSplitFeature(0)
|
||||
.setThreshold(0.5),
|
||||
TreeNode.builder(1).setLeafValue(0.3),
|
||||
TreeNode.builder(2)
|
||||
.setThreshold(0.0)
|
||||
.setSplitFeature(3)
|
||||
.setLeftChild(3)
|
||||
.setRightChild(4),
|
||||
TreeNode.builder(3).setLeafValue(0.1),
|
||||
TreeNode.builder(4).setLeafValue(0.2))
|
||||
.build();
|
||||
Tree tree2 = Tree.builder()
|
||||
.setFeatureNames(featureNames)
|
||||
.setNodes(TreeNode.builder(0)
|
||||
.setLeftChild(1)
|
||||
.setRightChild(2)
|
||||
.setSplitFeature(2)
|
||||
.setThreshold(1.0),
|
||||
TreeNode.builder(1).setLeafValue(1.5),
|
||||
TreeNode.builder(2).setLeafValue(0.9))
|
||||
.build();
|
||||
Tree tree3 = Tree.builder()
|
||||
.setFeatureNames(featureNames)
|
||||
.setNodes(TreeNode.builder(0)
|
||||
.setLeftChild(1)
|
||||
.setRightChild(2)
|
||||
.setSplitFeature(1)
|
||||
.setThreshold(0.2),
|
||||
TreeNode.builder(1).setLeafValue(1.5),
|
||||
TreeNode.builder(2).setLeafValue(0.9))
|
||||
.build();
|
||||
return Ensemble.builder()
|
||||
.setTargetType(TargetType.REGRESSION)
|
||||
.setFeatureNames(featureNames)
|
||||
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
|
||||
.setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5)))
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@After
|
||||
public void clearMlState() throws Exception {
|
||||
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
|
||||
|
|
|
@ -112,6 +112,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
|
||||
|
@ -183,6 +184,7 @@ import org.elasticsearch.xpack.ml.action.TransportPutDataFrameAnalyticsAction;
|
|||
import org.elasticsearch.xpack.ml.action.TransportPutDatafeedAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportPutFilterAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportPutJobAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
|
||||
|
@ -273,6 +275,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
|
|||
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
|
||||
|
@ -773,7 +776,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
new RestExplainDataFrameAnalyticsAction(restController),
|
||||
new RestGetTrainedModelsAction(restController),
|
||||
new RestDeleteTrainedModelAction(restController),
|
||||
new RestGetTrainedModelsStatsAction(restController)
|
||||
new RestGetTrainedModelsStatsAction(restController),
|
||||
new RestPutTrainedModelAction(restController)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -844,7 +848,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
new ActionHandler<>(InternalInferModelAction.INSTANCE, TransportInternalInferModelAction.class),
|
||||
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
|
||||
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
|
||||
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class)
|
||||
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class),
|
||||
new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class)
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.block.ClusterBlockException;
|
||||
import org.elasticsearch.cluster.block.ClusterBlockLevel;
|
||||
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.license.LicenseUtils;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
||||
|
||||
public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Request, Response> {
|
||||
|
||||
private final TrainedModelProvider trainedModelProvider;
|
||||
private final XPackLicenseState licenseState;
|
||||
private final NamedXContentRegistry xContentRegistry;
|
||||
private final Client client;
|
||||
|
||||
@Inject
|
||||
public TransportPutTrainedModelAction(TransportService transportService, ClusterService clusterService,
|
||||
ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters,
|
||||
IndexNameExpressionResolver indexNameExpressionResolver, Client client,
|
||||
TrainedModelProvider trainedModelProvider, NamedXContentRegistry xContentRegistry) {
|
||||
super(PutTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, Request::new,
|
||||
indexNameExpressionResolver);
|
||||
this.licenseState = licenseState;
|
||||
this.trainedModelProvider = trainedModelProvider;
|
||||
this.xContentRegistry = xContentRegistry;
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String executor() {
|
||||
return ThreadPool.Names.SAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Response read(StreamInput in) throws IOException {
|
||||
return new Response(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void masterOperation(Request request, ClusterState state, ActionListener<Response> listener) {
|
||||
try {
|
||||
request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
|
||||
request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
|
||||
} catch (IOException ex) {
|
||||
listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
|
||||
ex,
|
||||
request.getTrainedModelConfig().getModelId()));
|
||||
return;
|
||||
} catch (ElasticsearchException ex) {
|
||||
listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.",
|
||||
ex,
|
||||
request.getTrainedModelConfig().getModelId()));
|
||||
return;
|
||||
}
|
||||
|
||||
TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
|
||||
.setVersion(Version.CURRENT)
|
||||
.setCreateTime(Instant.now())
|
||||
.setCreatedBy("api_user")
|
||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||
.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
|
||||
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
|
||||
.build();
|
||||
|
||||
ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
|
||||
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
|
||||
storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
|
||||
listener::onFailure
|
||||
)),
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
ActionListener<Void> modelIdTagCheckListener = ActionListener.wrap(
|
||||
r -> checkTagsAgainstModelIds(request.getTrainedModelConfig().getTags(), tagsModelIdCheckListener),
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
checkModelIdAgainstTags(request.getTrainedModelConfig().getModelId(), modelIdTagCheckListener);
|
||||
}
|
||||
|
||||
private void checkModelIdAgainstTags(String modelId, ActionListener<Void> listener) {
|
||||
QueryBuilder builder = QueryBuilders.constantScoreQuery(
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId)));
|
||||
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
|
||||
SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
|
||||
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
||||
ML_ORIGIN,
|
||||
searchRequest,
|
||||
ActionListener.<SearchResponse>wrap(
|
||||
response -> {
|
||||
if (response.getHits().getTotalHits().value > 0) {
|
||||
listener.onFailure(
|
||||
ExceptionsHelper.badRequestException(
|
||||
Messages.getMessage(Messages.INFERENCE_MODEL_ID_AND_TAGS_UNIQUE, modelId)));
|
||||
return;
|
||||
}
|
||||
listener.onResponse(null);
|
||||
},
|
||||
listener::onFailure
|
||||
),
|
||||
client::search);
|
||||
}
|
||||
|
||||
private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> listener) {
|
||||
if (tags.isEmpty()) {
|
||||
listener.onResponse(null);
|
||||
return;
|
||||
}
|
||||
|
||||
QueryBuilder builder = QueryBuilders.constantScoreQuery(
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags)));
|
||||
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
|
||||
SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
|
||||
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
||||
ML_ORIGIN,
|
||||
searchRequest,
|
||||
ActionListener.<SearchResponse>wrap(
|
||||
response -> {
|
||||
if (response.getHits().getTotalHits().value > 0) {
|
||||
listener.onFailure(
|
||||
ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE, tags)));
|
||||
return;
|
||||
}
|
||||
listener.onResponse(null);
|
||||
},
|
||||
listener::onFailure
|
||||
),
|
||||
client::search);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
|
||||
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
|
||||
if (licenseState.isMachineLearningAllowed()) {
|
||||
super.doExecute(task, request, listener);
|
||||
} else {
|
||||
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -175,10 +175,12 @@ public class TrainedModelProvider {
|
|||
r -> {
|
||||
assert r.getItems().length == 2;
|
||||
if (r.getItems()[0].isFailed()) {
|
||||
|
||||
logger.error(new ParameterizedMessage(
|
||||
"[{}] failed to store trained model config for inference",
|
||||
trainedModelConfig.getModelId()),
|
||||
r.getItems()[0].getFailure().getCause());
|
||||
|
||||
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.rest.inference;
|
||||
|
||||
import org.elasticsearch.client.node.NodeClient;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.rest.BaseRestHandler;
|
||||
import org.elasticsearch.rest.RestController;
|
||||
import org.elasticsearch.rest.RestRequest;
|
||||
import org.elasticsearch.rest.action.RestToXContentListener;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RestPutTrainedModelAction extends BaseRestHandler {
|
||||
|
||||
public RestPutTrainedModelAction(RestController controller) {
|
||||
controller.registerHandler(RestRequest.Method.PUT,
|
||||
MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}",
|
||||
this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "xpack_ml_put_trained_model_action";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
|
||||
String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
|
||||
XContentParser parser = restRequest.contentParser();
|
||||
PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser);
|
||||
putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));
|
||||
|
||||
return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
|
||||
}
|
||||
}
|
|
@ -13,7 +13,6 @@ import org.elasticsearch.action.ingest.SimulatePipelineAction;
|
|||
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.client.transport.TransportClient;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
|
@ -37,26 +36,30 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.client.MachineLearningClient;
|
||||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
||||
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasItem;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
@ -561,12 +564,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
|||
" \"target_field\": \"regression_value\",\n" +
|
||||
" \"model_id\": \"modelprocessorlicensetest\",\n" +
|
||||
" \"inference_config\": {\"regression\": {}},\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" \"field_mappings\": {}\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
// Creating a pipeline should work
|
||||
|
@ -748,76 +746,22 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
|||
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
|
||||
}
|
||||
|
||||
private void putInferenceModel(String modelId) throws Exception {
|
||||
String config = "" +
|
||||
"{\n" +
|
||||
" \"model_id\": \"" + modelId + "\",\n" +
|
||||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model for classification\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"created_by\": \"benwtrent\",\n" +
|
||||
" \"license_level\": \"platinum\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0,\n" +
|
||||
" \"estimated_operations\": 0,\n" +
|
||||
" \"created_time\": 0\n" +
|
||||
"}";
|
||||
String definition = "" +
|
||||
"{" +
|
||||
" \"trained_model\": {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"leaf_value\": 2\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" }" +
|
||||
"}";
|
||||
String compressedDefinitionString =
|
||||
InferenceToXContentCompressor.deflate(new BytesArray(definition.getBytes(StandardCharsets.UTF_8)));
|
||||
String compressedDefinition = "" +
|
||||
"{" +
|
||||
" \"model_id\": \"" + modelId + "\",\n" +
|
||||
" \"doc_type\": \"" + TrainedModelDefinitionDoc.NAME + "\",\n" +
|
||||
" \"doc_num\": " + 0 + ",\n" +
|
||||
" \"compression_version\": " + 1 + ",\n" +
|
||||
" \"total_definition_length\": " + compressedDefinitionString.length() + ",\n" +
|
||||
" \"definition_length\": " + compressedDefinitionString.length() + ",\n" +
|
||||
" \"definition\": \"" + compressedDefinitionString + "\"\n" +
|
||||
"}";
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(modelId)
|
||||
.setSource(config, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(TrainedModelDefinitionDoc.docId(modelId, 0))
|
||||
.setSource(compressedDefinition, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
private void putInferenceModel(String modelId) {
|
||||
TrainedModelConfig config = TrainedModelConfig.builder()
|
||||
.setParsedDefinition(
|
||||
new TrainedModelDefinition.Builder()
|
||||
.setTrainedModel(
|
||||
Tree.builder()
|
||||
.setTargetType(TargetType.REGRESSION)
|
||||
.setFeatureNames(Arrays.asList("feature1"))
|
||||
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
|
||||
.build())
|
||||
.setPreProcessors(Collections.emptyList()))
|
||||
.setModelId(modelId)
|
||||
.setDescription("test model for classification")
|
||||
.setInput(new TrainedModelInput(Arrays.asList("feature1")))
|
||||
.build();
|
||||
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
|
||||
}
|
||||
|
||||
private static OperationMode randomInvalidLicenseType() {
|
||||
|
|
|
@ -200,7 +200,6 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
return TrainedModelConfig.builder()
|
||||
.setCreatedBy("ml_test")
|
||||
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
|
||||
.setDescription("trained model config for test")
|
||||
.setModelId(modelId)
|
||||
.setVersion(Version.CURRENT)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"ml.put_trained_model":{
|
||||
"documentation":{
|
||||
"url":"TODO"
|
||||
},
|
||||
"stability":"experimental",
|
||||
"url":{
|
||||
"paths":[
|
||||
{
|
||||
"path":"/_ml/inference/{model_id}",
|
||||
"methods":[
|
||||
"PUT"
|
||||
],
|
||||
"parts":{
|
||||
"model_id":{
|
||||
"type":"string",
|
||||
"description":"The ID of the trained models to store"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"body": {
|
||||
"description":"The trained model configuration",
|
||||
"required":true
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,3 +1,74 @@
|
|||
setup:
|
||||
- skip:
|
||||
features: headers
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ml.put_trained_model:
|
||||
model_id: a-regression-model-0
|
||||
body: >
|
||||
{
|
||||
"description": "empty model for tests",
|
||||
"input": {"field_names": ["field1", "field2"]},
|
||||
"definition": {
|
||||
"preprocessors": [],
|
||||
"trained_model": {
|
||||
"tree": {
|
||||
"feature_names": ["field1", "field2"],
|
||||
"tree_structure": [
|
||||
{"node_index": 0, "leaf_value": 1}
|
||||
],
|
||||
"target_type": "regression"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ml.put_trained_model:
|
||||
model_id: a-regression-model-1
|
||||
body: >
|
||||
{
|
||||
"description": "empty model for tests",
|
||||
"input": {"field_names": ["field1", "field2"]},
|
||||
"definition": {
|
||||
"preprocessors": [],
|
||||
"trained_model": {
|
||||
"tree": {
|
||||
"feature_names": ["field1", "field2"],
|
||||
"tree_structure": [
|
||||
{"node_index": 0, "leaf_value": 1}
|
||||
],
|
||||
"target_type": "regression"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ml.put_trained_model:
|
||||
model_id: a-classification-model
|
||||
body: >
|
||||
{
|
||||
"description": "empty model for tests",
|
||||
"input": {"field_names": ["field1", "field2"]},
|
||||
"definition": {
|
||||
"preprocessors": [],
|
||||
"trained_model": {
|
||||
"tree": {
|
||||
"feature_names": ["field1", "field2"],
|
||||
"tree_structure": [
|
||||
{"node_index": 0, "leaf_value": 1}
|
||||
],
|
||||
"target_type": "classification",
|
||||
"classification_labels": ["no", "yes"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
---
|
||||
"Test get given missing trained model":
|
||||
|
||||
|
@ -24,56 +95,52 @@
|
|||
- match: { count: 0 }
|
||||
- match: { trained_model_configs: [] }
|
||||
---
|
||||
"Test get models":
|
||||
- do:
|
||||
ml.get_trained_models:
|
||||
model_id: "*"
|
||||
- match: { count: 4 }
|
||||
- match: { trained_model_configs.0.model_id: "a-classification-model" }
|
||||
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
|
||||
- match: { trained_model_configs.2.model_id: "a-regression-model-1" }
|
||||
|
||||
- do:
|
||||
ml.get_trained_models:
|
||||
model_id: "a-regression*"
|
||||
- match: { count: 2 }
|
||||
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
|
||||
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
|
||||
|
||||
- do:
|
||||
ml.get_trained_models:
|
||||
model_id: "*"
|
||||
from: 0
|
||||
size: 2
|
||||
- match: { count: 4 }
|
||||
- match: { trained_model_configs.0.model_id: "a-classification-model" }
|
||||
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
|
||||
|
||||
- do:
|
||||
ml.get_trained_models:
|
||||
model_id: "*"
|
||||
from: 1
|
||||
size: 1
|
||||
- match: { count: 4 }
|
||||
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
|
||||
---
|
||||
"Test delete given unused trained model":
|
||||
|
||||
- do:
|
||||
index:
|
||||
id: trained_model_config-unused-regression-model-0
|
||||
index: .ml-inference-000001
|
||||
body: >
|
||||
{
|
||||
"model_id": "unused-regression-model",
|
||||
"created_by": "ml_tests",
|
||||
"version": "8.0.0",
|
||||
"description": "empty model for tests",
|
||||
"create_time": 0,
|
||||
"model_version": 0,
|
||||
"model_type": "local"
|
||||
}
|
||||
- do:
|
||||
indices.refresh: {}
|
||||
|
||||
- do:
|
||||
ml.delete_trained_model:
|
||||
model_id: "unused-regression-model"
|
||||
model_id: "a-classification-model"
|
||||
- match: { acknowledged: true }
|
||||
|
||||
---
|
||||
"Test delete with missing model":
|
||||
- do:
|
||||
catch: missing
|
||||
ml.delete_trained_model:
|
||||
model_id: "missing-trained-model"
|
||||
|
||||
---
|
||||
"Test delete given used trained model":
|
||||
- do:
|
||||
index:
|
||||
id: trained_model_config-used-regression-model-0
|
||||
index: .ml-inference-000001
|
||||
body: >
|
||||
{
|
||||
"model_id": "used-regression-model",
|
||||
"created_by": "ml_tests",
|
||||
"version": "8.0.0",
|
||||
"description": "empty model for tests",
|
||||
"create_time": 0,
|
||||
"model_version": 0,
|
||||
"model_type": "local"
|
||||
}
|
||||
- do:
|
||||
indices.refresh: {}
|
||||
|
||||
- do:
|
||||
ingest.put_pipeline:
|
||||
id: "regression-model-pipeline"
|
||||
|
@ -82,7 +149,7 @@
|
|||
"processors": [
|
||||
{
|
||||
"inference" : {
|
||||
"model_id" : "used-regression-model",
|
||||
"model_id" : "a-regression-model-0",
|
||||
"inference_config": {"regression": {}},
|
||||
"target_field": "regression_field",
|
||||
"field_mappings": {}
|
||||
|
@ -95,12 +162,12 @@
|
|||
- do:
|
||||
catch: conflict
|
||||
ml.delete_trained_model:
|
||||
model_id: "used-regression-model"
|
||||
model_id: "a-regression-model-0"
|
||||
---
|
||||
"Test get pre-packaged trained models":
|
||||
- do:
|
||||
ml.get_trained_models:
|
||||
model_id: "_all"
|
||||
model_id: "lang_ident_model_1"
|
||||
allow_no_match: false
|
||||
- match: { count: 1 }
|
||||
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
|
||||
|
|
Loading…
Reference in New Issue