[7.x] [ML][Inference] adding .ml-inference* index and storage (#47267) (#47310)

* [ML][Inference] adding .ml-inference* index and storage (#47267)

* [ML][Inference] adding .ml-inference* index and storage

* Addressing PR comments

* Allowing null definition, adding validation tests for model config

* fixing line length

* adjusting for backport
This commit is contained in:
Benjamin Trent 2019-10-01 08:20:33 -04:00 committed by GitHub
parent c43e932a0c
commit 4335e07716
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1643 additions and 65 deletions

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.transform.transforms.util; package org.elasticsearch.client.common;
import org.elasticsearch.common.time.DateFormatters; import org.elasticsearch.common.time.DateFormatters;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
@ -46,6 +46,14 @@ public final class TimeUtil {
"unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]"); "unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
} }
/**
* Parse out an Instant object given the current parser and field name.
*
* @param parser current XContentParser
* @param fieldName the field's preferred name (utilized in exception)
* @return parsed Instant object
* @throws IOException from XContentParser
*/
public static Instant parseTimeFieldToInstant(XContentParser parser, String fieldName) throws IOException { public static Instant parseTimeFieldToInstant(XContentParser parser, String fieldName) throws IOException {
if (parser.currentToken() == XContentParser.Token.VALUE_NUMBER) { if (parser.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return Instant.ofEpochMilli(parser.longValue()); return Instant.ofEpochMilli(parser.longValue());

View File

@ -18,7 +18,7 @@
*/ */
package org.elasticsearch.client.ml.calendars; package org.elasticsearch.client.ml.calendars;
import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;

View File

@ -20,7 +20,7 @@
package org.elasticsearch.client.ml.dataframe; package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.client.transform.transforms.util.TimeUtil; import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;

View File

@ -0,0 +1,34 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.common.xcontent.ToXContentObject;
/**
* Simple interface for XContent Objects that are named.
*
* This affords more general handling when serializing and de-serializing this type of XContent when it is used in a NamedObjects
* parser.
*/
public interface NamedXContentObject extends ToXContentObject {
/**
* @return The name of the XContentObject that is to be serialized
*/
String getName();
}

View File

@ -0,0 +1,57 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.List;
public final class NamedXContentObjectHelper {
private NamedXContentObjectHelper() {}
public static XContentBuilder writeNamedObjects(XContentBuilder builder,
ToXContent.Params params,
boolean useExplicitOrder,
String namedObjectsName,
List<? extends NamedXContentObject> namedObjects) throws IOException {
if (useExplicitOrder) {
builder.startArray(namedObjectsName);
} else {
builder.startObject(namedObjectsName);
}
for (NamedXContentObject object : namedObjects) {
if (useExplicitOrder) {
builder.startObject();
}
builder.field(object.getName(), object, params);
if (useExplicitOrder) {
builder.endObject();
}
}
if (useExplicitOrder) {
builder.endArray();
} else {
builder.endObject();
}
return builder;
}
}

View File

@ -0,0 +1,299 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class TrainedModelConfig implements ToXContentObject {
public static final String NAME = "trained_model_doc";
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
public static final ParseField DESCRIPTION = new ParseField("description");
public static final ParseField CREATED_TIME = new ParseField("created_time");
public static final ParseField MODEL_VERSION = new ParseField("model_version");
public static final ParseField DEFINITION = new ParseField("definition");
public static final ParseField MODEL_TYPE = new ParseField("model_type");
public static final ParseField METADATA = new ParseField("metadata");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
TrainedModelConfig.Builder::new);
static {
PARSER.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID);
PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
PARSER.declareField(TrainedModelConfig.Builder::setCreatedTime,
(p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()),
CREATED_TIME,
ObjectParser.ValueType.VALUE);
PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
(p, c, n) -> p.namedObject(TrainedModel.class, n, null),
(modelDocBuilder) -> { /* Noop does not matter client side */ },
DEFINITION);
}
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
private final String modelId;
private final String createdBy;
private final Version version;
private final String description;
private final Instant createdTime;
private final Long modelVersion;
private final String modelType;
private final Map<String, Object> metadata;
private final TrainedModel definition;
TrainedModelConfig(String modelId,
String createdBy,
Version version,
String description,
Instant createdTime,
Long modelVersion,
String modelType,
TrainedModel definition,
Map<String, Object> metadata) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
this.createdTime = Instant.ofEpochMilli(createdTime.toEpochMilli());
this.modelType = modelType;
this.definition = definition;
this.description = description;
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.modelVersion = modelVersion;
}
public String getModelId() {
return modelId;
}
public String getCreatedBy() {
return createdBy;
}
public Version getVersion() {
return version;
}
public String getDescription() {
return description;
}
public Instant getCreatedTime() {
return createdTime;
}
public Long getModelVersion() {
return modelVersion;
}
public String getModelType() {
return modelType;
}
public Map<String, Object> getMetadata() {
return metadata;
}
public TrainedModel getDefinition() {
return definition;
}
public static Builder builder() {
return new Builder();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (modelId != null) {
builder.field(MODEL_ID.getPreferredName(), modelId);
}
if (createdBy != null) {
builder.field(CREATED_BY.getPreferredName(), createdBy);
}
if (version != null) {
builder.field(VERSION.getPreferredName(), version.toString());
}
if (description != null) {
builder.field(DESCRIPTION.getPreferredName(), description);
}
if (createdTime != null) {
builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli());
}
if (modelVersion != null) {
builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
}
if (modelType != null) {
builder.field(MODEL_TYPE.getPreferredName(), modelType);
}
if (definition != null) {
NamedXContentObjectHelper.writeNamedObjects(builder,
params,
false,
DEFINITION.getPreferredName(),
Collections.singletonList(definition));
}
if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata);
}
builder.endObject();
return builder;
}
@Override
public String toString() {
return Strings.toString(this);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelConfig that = (TrainedModelConfig) o;
return Objects.equals(modelId, that.modelId) &&
Objects.equals(createdBy, that.createdBy) &&
Objects.equals(version, that.version) &&
Objects.equals(description, that.description) &&
Objects.equals(createdTime, that.createdTime) &&
Objects.equals(modelVersion, that.modelVersion) &&
Objects.equals(modelType, that.modelType) &&
Objects.equals(definition, that.definition) &&
Objects.equals(metadata, that.metadata);
}
@Override
public int hashCode() {
return Objects.hash(modelId,
createdBy,
version,
createdTime,
modelType,
definition,
description,
metadata,
modelVersion);
}
public static class Builder {
private String modelId;
private String createdBy;
private Version version;
private String description;
private Instant createdTime;
private Long modelVersion;
private String modelType;
private Map<String, Object> metadata;
private TrainedModel definition;
public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}
private Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
}
private Builder setVersion(Version version) {
this.version = version;
return this;
}
private Builder setVersion(String version) {
return this.setVersion(Version.fromString(version));
}
public Builder setDescription(String description) {
this.description = description;
return this;
}
private Builder setCreatedTime(Instant createdTime) {
this.createdTime = createdTime;
return this;
}
public Builder setModelVersion(Long modelVersion) {
this.modelVersion = modelVersion;
return this;
}
public Builder setModelType(String modelType) {
this.modelType = modelType;
return this;
}
public Builder setMetadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}
public Builder setDefinition(TrainedModel definition) {
this.definition = definition;
return this;
}
private Builder setDefinition(List<TrainedModel> definition) {
assert definition.size() == 1;
return setDefinition(definition.get(0));
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
createdBy,
version,
description,
createdTime,
modelVersion,
modelType,
definition,
metadata);
}
}
}

View File

@ -18,11 +18,11 @@
*/ */
package org.elasticsearch.client.ml.inference.trainedmodel; package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.client.ml.inference.NamedXContentObject;
import java.util.List; import java.util.List;
public interface TrainedModel extends ToXContentObject { public interface TrainedModel extends NamedXContentObject {
/** /**
* @return List of featureNames expected by the model. In the order that they are expected * @return List of featureNames expected by the model. In the order that they are expected

View File

@ -18,7 +18,7 @@
*/ */
package org.elasticsearch.client.ml.job.config; package org.elasticsearch.client.ml.job.config;
import org.elasticsearch.client.ml.job.util.TimeUtil; import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;

View File

@ -18,8 +18,8 @@
*/ */
package org.elasticsearch.client.ml.job.process; package org.elasticsearch.client.ml.job.process;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ObjectParser.ValueType;

View File

@ -18,9 +18,9 @@
*/ */
package org.elasticsearch.client.ml.job.process; package org.elasticsearch.client.ml.job.process;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.results.Result; import org.elasticsearch.client.ml.job.results.Result;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ObjectParser.ValueType;

View File

@ -19,8 +19,8 @@
package org.elasticsearch.client.ml.job.process; package org.elasticsearch.client.ml.job.process;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ObjectParser.ValueType;

View File

@ -18,8 +18,8 @@
*/ */
package org.elasticsearch.client.ml.job.results; package org.elasticsearch.client.ml.job.results;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ObjectParser.ValueType;

View File

@ -18,8 +18,8 @@
*/ */
package org.elasticsearch.client.ml.job.results; package org.elasticsearch.client.ml.job.results;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ObjectParser.ValueType;

View File

@ -18,8 +18,8 @@
*/ */
package org.elasticsearch.client.ml.job.results; package org.elasticsearch.client.ml.job.results;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ObjectParser.ValueType;

View File

@ -18,8 +18,8 @@
*/ */
package org.elasticsearch.client.ml.job.results; package org.elasticsearch.client.ml.job.results;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ObjectParser.ValueType;

View File

@ -18,8 +18,8 @@
*/ */
package org.elasticsearch.client.ml.job.results; package org.elasticsearch.client.ml.job.results;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.job.config.Job; import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.TimeUtil;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;

View File

@ -1,48 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.job.util;
import org.elasticsearch.common.time.DateFormatters;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.time.format.DateTimeFormatter;
import java.util.Date;
public final class TimeUtil {
/**
* Parse out a Date object given the current parser and field name.
*
* @param parser current XContentParser
* @param fieldName the field's preferred name (utilized in exception)
* @return parsed Date object
* @throws IOException from XContentParser
*/
public static Date parseTimeField(XContentParser parser, String fieldName) throws IOException {
if (parser.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return new Date(parser.longValue());
} else if (parser.currentToken() == XContentParser.Token.VALUE_STRING) {
return new Date(DateFormatters.from(DateTimeFormatter.ISO_INSTANT.parse(parser.text())).toInstant().toEpochMilli());
}
throw new IllegalArgumentException(
"unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
}
}

View File

@ -19,7 +19,7 @@
package org.elasticsearch.client.transform.transforms; package org.elasticsearch.client.transform.transforms;
import org.elasticsearch.client.transform.transforms.util.TimeUtil; import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;

View File

@ -20,8 +20,8 @@
package org.elasticsearch.client.transform.transforms; package org.elasticsearch.client.transform.transforms;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.transform.transforms.pivot.PivotConfig; import org.elasticsearch.client.transform.transforms.pivot.PivotConfig;
import org.elasticsearch.client.transform.transforms.util.TimeUtil;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;

View File

@ -0,0 +1,112 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
public class NamedXContentObjectHelperTests extends ESTestCase {
static class NamedTestObject implements NamedXContentObject {
private String fieldValue;
public static final ObjectParser<NamedTestObject, Void> PARSER =
new ObjectParser<>("my_named_object", true, NamedTestObject::new);
static {
PARSER.declareString(NamedTestObject::setFieldValue, new ParseField("my_field"));
}
NamedTestObject() {
}
NamedTestObject(String value) {
this.fieldValue = value;
}
@Override
public String getName() {
return "my_named_object";
}
public void setFieldValue(String value) {
this.fieldValue = value;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fieldValue != null) {
builder.field("my_field", fieldValue);
}
builder.endObject();
return builder;
}
}
public void testSerializeInOrder() throws IOException {
String expected =
"{\"my_objects\":[{\"my_named_object\":{\"my_field\":\"value1\"}},{\"my_named_object\":{\"my_field\":\"value2\"}}]}";
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
List<NamedXContentObject> objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2"));
NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, true, "my_objects", objects);
builder.endObject();
assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected));
}
}
public void testSerialize() throws IOException {
String expected = "{\"my_objects\":{\"my_named_object\":{\"my_field\":\"value1\"},\"my_named_object\":{\"my_field\":\"value2\"}}}";
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
List<NamedXContentObject> objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2"));
NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, false, "my_objects", objects);
builder.endObject();
assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected));
}
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(Collections.singletonList(new NamedXContentRegistry.Entry(NamedXContentObject.class,
new ParseField("my_named_object"),
(p, c) -> NamedTestObject.PARSER.apply(p, null))));
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
}

View File

@ -0,0 +1,76 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.Version;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedModelConfig> {
@Override
protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException {
return TrainedModelConfig.fromXContent(parser).build();
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
@Override
protected TrainedModelConfig createTestInstance() {
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : randomNonNegativeLong(),
randomAlphaOfLength(10),
randomFrom(TreeTests.createRandom()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
}

View File

@ -0,0 +1,372 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_doc";
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
public static final ParseField DESCRIPTION = new ParseField("description");
public static final ParseField CREATED_TIME = new ParseField("created_time");
public static final ParseField MODEL_VERSION = new ParseField("model_version");
public static final ParseField DEFINITION = new ParseField("definition");
public static final ParseField MODEL_TYPE = new ParseField("model_type");
public static final ParseField METADATA = new ParseField("metadata");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
public static final ObjectParser<TrainedModelConfig.Builder, Void> STRICT_PARSER = createParser(false);
private static ObjectParser<TrainedModelConfig.Builder, Void> createParser(boolean ignoreUnknownFields) {
ObjectParser<TrainedModelConfig.Builder, Void> parser = new ObjectParser<>(NAME,
ignoreUnknownFields,
TrainedModelConfig.Builder::new);
parser.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID);
parser.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
parser.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
parser.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
parser.declareField(TrainedModelConfig.Builder::setCreatedTime,
(p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()),
CREATED_TIME,
ObjectParser.ValueType.VALUE);
parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
parser.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
(p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
DEFINITION);
return parser;
}
public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
public static String documentId(String modelId, long modelVersion) {
return NAME + "-" + modelId + "-" + modelVersion;
}
private final String modelId;
private final String createdBy;
private final Version version;
private final String description;
private final Instant createdTime;
private final long modelVersion;
private final String modelType;
private final Map<String, Object> metadata;
// TODO how to reference and store large models that will not be executed in Java???
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
// TODO Should this be lazily parsed when loading via the index???
private final TrainedModel definition;
TrainedModelConfig(String modelId,
String createdBy,
Version version,
String description,
Instant createdTime,
Long modelVersion,
String modelType,
TrainedModel definition,
Map<String, Object> metadata) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
this.createdTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createdTime, CREATED_TIME).toEpochMilli());
this.modelType = ExceptionsHelper.requireNonNull(modelType, MODEL_TYPE);
this.definition = definition;
this.description = description;
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.modelVersion = modelVersion == null ? 0 : modelVersion;
}
public TrainedModelConfig(StreamInput in) throws IOException {
modelId = in.readString();
createdBy = in.readString();
version = Version.readVersion(in);
description = in.readOptionalString();
createdTime = in.readInstant();
modelVersion = in.readVLong();
modelType = in.readString();
definition = in.readOptionalNamedWriteable(TrainedModel.class);
metadata = in.readMap();
}
public String getModelId() {
return modelId;
}
public String getCreatedBy() {
return createdBy;
}
public Version getVersion() {
return version;
}
public String getDescription() {
return description;
}
public Instant getCreatedTime() {
return createdTime;
}
public long getModelVersion() {
return modelVersion;
}
public String getModelType() {
return modelType;
}
public Map<String, Object> getMetadata() {
return metadata;
}
@Nullable
public TrainedModel getDefinition() {
return definition;
}
public static Builder builder() {
return new Builder();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeString(createdBy);
Version.writeVersion(version, out);
out.writeOptionalString(description);
out.writeInstant(createdTime);
out.writeVLong(modelVersion);
out.writeString(modelType);
out.writeOptionalNamedWriteable(definition);
out.writeMap(metadata);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(CREATED_BY.getPreferredName(), createdBy);
builder.field(VERSION.getPreferredName(), version.toString());
if (description != null) {
builder.field(DESCRIPTION.getPreferredName(), description);
}
builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli());
builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
builder.field(MODEL_TYPE.getPreferredName(), modelType);
if (definition != null) {
NamedXContentObjectHelper.writeNamedObjects(builder,
params,
false,
DEFINITION.getPreferredName(),
Collections.singletonList(definition));
}
if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata);
}
builder.endObject();
return builder;
}
@Override
public String toString() {
return Strings.toString(this);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelConfig that = (TrainedModelConfig) o;
return Objects.equals(modelId, that.modelId) &&
Objects.equals(createdBy, that.createdBy) &&
Objects.equals(version, that.version) &&
Objects.equals(description, that.description) &&
Objects.equals(createdTime, that.createdTime) &&
Objects.equals(modelVersion, that.modelVersion) &&
Objects.equals(modelType, that.modelType) &&
Objects.equals(definition, that.definition) &&
Objects.equals(metadata, that.metadata);
}
@Override
public int hashCode() {
return Objects.hash(modelId,
createdBy,
version,
createdTime,
modelType,
definition,
description,
metadata,
modelVersion);
}
public static class Builder {
private String modelId;
private String createdBy;
private Version version;
private String description;
private Instant createdTime;
private Long modelVersion;
private String modelType;
private Map<String, Object> metadata;
private TrainedModel definition;
public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}
public Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
}
public Builder setVersion(Version version) {
this.version = version;
return this;
}
private Builder setVersion(String version) {
return this.setVersion(Version.fromString(version));
}
public Builder setDescription(String description) {
this.description = description;
return this;
}
public Builder setCreatedTime(Instant createdTime) {
this.createdTime = createdTime;
return this;
}
public Builder setModelVersion(Long modelVersion) {
this.modelVersion = modelVersion;
return this;
}
public Builder setModelType(String modelType) {
this.modelType = modelType;
return this;
}
public Builder setMetadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}
public Builder setDefinition(TrainedModel definition) {
this.definition = definition;
return this;
}
private Builder setDefinition(List<TrainedModel> definition) {
if (definition.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
DEFINITION.getPreferredName());
}
return setDefinition(definition.get(0));
}
// TODO move to REST level instead of here in the builder
public void validate() {
// We require a definition to be available until we support other means of supplying the definition
ExceptionsHelper.requireNonNull(definition, DEFINITION);
ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
if (MlStrings.isValidId(modelId) == false) {
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId));
}
if (MlStrings.hasValidLengthForId(modelId) == false) {
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG,
MODEL_ID.getPreferredName(),
modelId,
MlStrings.ID_LENGTH_LIMIT));
}
if (version != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", VERSION.getPreferredName());
}
if (createdBy != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATED_BY.getPreferredName());
}
if (createdTime != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATED_TIME.getPreferredName());
}
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
createdBy,
version,
description,
createdTime,
modelVersion,
modelType,
definition,
metadata);
}
public TrainedModelConfig build(Version version) {
return new TrainedModelConfig(
modelId,
createdBy,
version,
description,
Instant.now(),
modelVersion,
modelType,
definition,
metadata);
}
}
}

View File

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.persistence;
/**
* Class containing the index constants so that the index version, name, and prefix are available to a wider audience.
*/
public final class InferenceIndexConstants {
public static final String INDEX_VERSION = "000001";
public static final String INDEX_NAME_PREFIX = ".ml-inference-";
public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;
private InferenceIndexConstants() {}
}

View File

@ -75,9 +75,16 @@ public final class Messages {
"Inconsistent {0}; ''{1}'' specified in the body differs from ''{2}'' specified as a URL argument"; "Inconsistent {0}; ''{1}'' specified in the body differs from ''{2}'' specified as a URL argument";
public static final String INVALID_ID = "Invalid {0}; ''{1}'' can contain lowercase alphanumeric (a-z and 0-9), hyphens or " + public static final String INVALID_ID = "Invalid {0}; ''{1}'' can contain lowercase alphanumeric (a-z and 0-9), hyphens or " +
"underscores; must start and end with alphanumeric"; "underscores; must start and end with alphanumeric";
public static final String ID_TOO_LONG = "Invalid {0}; ''{1}'' cannot contain more than {2} characters.";
public static final String INVALID_GROUP = "Invalid group id ''{0}''; must be non-empty string and may contain lowercase alphanumeric" + public static final String INVALID_GROUP = "Invalid group id ''{0}''; must be non-empty string and may contain lowercase alphanumeric" +
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric"; " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] with version [{1}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL =
"Failed to serialize the trained model [{0}] with version [{1}] for storage";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}] with version [{1}]";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created"; public static final String JOB_AUDIT_CREATED = "Job created";
public static final String JOB_AUDIT_UPDATED = "Job updated: {0}"; public static final String JOB_AUDIT_UPDATED = "Job updated: {0}";

View File

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.List;
public final class NamedXContentObjectHelper {
private NamedXContentObjectHelper() {}
public static XContentBuilder writeNamedObjects(XContentBuilder builder,
ToXContent.Params params,
boolean useExplicitOrder,
String namedObjectsName,
List<? extends NamedXContentObject> namedObjects) throws IOException {
if (useExplicitOrder) {
builder.startArray(namedObjectsName);
} else {
builder.startObject(namedObjectsName);
}
for (NamedXContentObject object : namedObjects) {
if (useExplicitOrder) {
builder.startObject();
}
builder.field(object.getName(), object, params);
if (useExplicitOrder) {
builder.endObject();
}
}
if (useExplicitOrder) {
builder.endArray();
} else {
builder.endObject();
}
return builder;
}
}

View File

@ -0,0 +1,134 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.junit.Before;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.equalTo;
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
private boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException {
return TrainedModelConfig.fromXContent(parser, lenient).build();
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
@Override
protected TrainedModelConfig createTestInstance() {
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : randomNonNegativeLong(),
randomAlphaOfLength(10),
randomBoolean() ? null : randomFrom(TreeTests.createRandom()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
}
@Override
protected Writeable.Reader<TrainedModelConfig> instanceReader() {
return TrainedModelConfig::new;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
public void testValidateWithNullDefinition() {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
}
public void testValidateWithInvalidID() {
String modelId = "InvalidID-";
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
}
public void testValidateWithLongID() {
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
}
public void testValidateWithIllegallyUserProvidedFields() {
String modelId = "simplemodel";
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder()
.setDefinition(randomFrom(TreeTests.createRandom()))
.setCreatedTime(Instant.now())
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder()
.setDefinition(randomFrom(TreeTests.createRandom()))
.setVersion(Version.CURRENT)
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder()
.setDefinition(randomFrom(TreeTests.createRandom()))
.setCreatedBy("ml_user")
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
}
}

View File

@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.client.ml.inference.NamedXContentObject;
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
public class NamedXContentObjectHelperTests extends ESTestCase {
static class NamedTestObject implements NamedXContentObject {
private String fieldValue;
public static final ObjectParser<NamedTestObject, Void> PARSER =
new ObjectParser<>("my_named_object", true, NamedTestObject::new);
static {
PARSER.declareString(NamedTestObject::setFieldValue, new ParseField("my_field"));
}
NamedTestObject() {
}
NamedTestObject(String value) {
this.fieldValue = value;
}
@Override
public String getName() {
return "my_named_object";
}
void setFieldValue(String value) {
this.fieldValue = value;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fieldValue != null) {
builder.field("my_field", fieldValue);
}
builder.endObject();
return builder;
}
}
public void testSerializeInOrder() throws IOException {
String expected =
"{\"my_objects\":[{\"my_named_object\":{\"my_field\":\"value1\"}},{\"my_named_object\":{\"my_field\":\"value2\"}}]}";
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
List<NamedXContentObject> objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2"));
NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, true, "my_objects", objects);
builder.endObject();
assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected));
}
}
public void testSerialize() throws IOException {
String expected = "{\"my_objects\":{\"my_named_object\":{\"my_field\":\"value1\"},\"my_named_object\":{\"my_field\":\"value2\"}}}";
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
List<NamedXContentObject> objects = Arrays.asList(new NamedTestObject("value1"), new NamedTestObject("value2"));
NamedXContentObjectHelper.writeNamedObjects(builder, ToXContent.EMPTY_PARAMS, false, "my_objects", objects);
builder.endObject();
assertThat(BytesReference.bytes(builder).utf8ToString(), equalTo(expected));
}
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(Collections.singletonList(new NamedXContentRegistry.Entry(NamedXContentObject.class,
new ParseField("my_named_object"),
(p, c) -> NamedTestObject.PARSER.apply(p, null))));
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
}

View File

@ -123,6 +123,7 @@ import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
@ -198,6 +199,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcess
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory;
import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
@ -906,6 +908,12 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
logger.error("Error loading the template for the " + AnomalyDetectorsIndex.jobResultsIndexPrefix() + " indices", e); logger.error("Error loading the template for the " + AnomalyDetectorsIndex.jobResultsIndexPrefix() + " indices", e);
} }
try {
templates.put(InferenceIndexConstants.LATEST_INDEX_NAME, InferenceInternalIndex.getIndexTemplateMetaData());
} catch (IOException e) {
logger.error("Error loading the template for the " + InferenceIndexConstants.LATEST_INDEX_NAME + " index", e);
}
return templates; return templates;
}; };
} }
@ -917,7 +925,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
AuditorField.NOTIFICATIONS_INDEX, AuditorField.NOTIFICATIONS_INDEX,
MlMetaIndex.INDEX_NAME, MlMetaIndex.INDEX_NAME,
AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX, AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX,
AnomalyDetectorsIndex.jobResultsIndexPrefix()); AnomalyDetectorsIndex.jobResultsIndexPrefix(),
InferenceIndexConstants.LATEST_INDEX_NAME);
for (String templateName : templateNames) { for (String templateName : templateNames) {
allPresent = allPresent && TemplateUtils.checkTemplateExistsAndVersionIsGTECurrentVersion(templateName, clusterState); allPresent = allPresent && TemplateUtils.checkTemplateExistsAndVersionIsGTECurrentVersion(templateName, clusterState);
} }

View File

@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.persistence;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.cluster.metadata.IndexTemplateMetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import java.io.IOException;
import java.util.Collections;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;
import static org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants.LATEST_INDEX_NAME;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DATE;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.addMetaInformation;
/**
* Changelog of internal index versions
*
* Please list changes, increase the version in {@link InferenceInternalIndex} if you are 1st in this release cycle
*
* version 1 (7.5): initial
*/
public final class InferenceInternalIndex {
private InferenceInternalIndex() {}
public static XContentBuilder mappings() throws IOException {
return configMapping(SINGLE_MAPPING_NAME);
}
public static IndexTemplateMetaData getIndexTemplateMetaData() throws IOException {
IndexTemplateMetaData inferenceTemplate = IndexTemplateMetaData.builder(LATEST_INDEX_NAME)
.patterns(Collections.singletonList(LATEST_INDEX_NAME))
.version(Version.CURRENT.id)
.settings(Settings.builder()
// the configurations are expected to be small
.put(IndexMetaData.SETTING_NUMBER_OF_SHARDS, 1)
.put(IndexMetaData.SETTING_AUTO_EXPAND_REPLICAS, "0-1"))
.putMapping(SINGLE_MAPPING_NAME, Strings.toString(mappings()))
.build();
return inferenceTemplate;
}
public static XContentBuilder configMapping(String mappingType) throws IOException {
XContentBuilder builder = jsonBuilder();
builder.startObject();
builder.startObject(mappingType);
addMetaInformation(builder);
// do not allow anything outside of the defined schema
builder.field(DYNAMIC, "false");
builder.startObject(PROPERTIES);
addInferenceDocFields(builder);
return builder.endObject()
.endObject()
.endObject();
}
private static void addInferenceDocFields(XContentBuilder builder) throws IOException {
builder.startObject(TrainedModelConfig.MODEL_ID.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(TrainedModelConfig.CREATED_BY.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(TrainedModelConfig.VERSION.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(TrainedModelConfig.DESCRIPTION.getPreferredName())
.field(TYPE, TEXT)
.endObject()
.startObject(TrainedModelConfig.CREATED_TIME.getPreferredName())
.field(TYPE, DATE)
.endObject()
.startObject(TrainedModelConfig.MODEL_VERSION.getPreferredName())
.field(TYPE, LONG)
.endObject()
.startObject(TrainedModelConfig.DEFINITION.getPreferredName())
.field(ENABLED, false)
.endObject()
.startObject(TrainedModelConfig.MODEL_TYPE.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(TrainedModelConfig.METADATA.getPreferredName())
.field(ENABLED, false)
.endObject();
}
}

View File

@ -0,0 +1,134 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.persistence;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import java.io.IOException;
import java.io.InputStream;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class TrainedModelProvider {
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
private final Client client;
private final NamedXContentRegistry xContentRegistry;
public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}
public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
XContentBuilder source = trainedModelConfig.toXContent(builder, ToXContent.EMPTY_PARAMS);
IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME)
.opType(DocWriteRequest.OpType.CREATE)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.id(TrainedModelConfig.documentId(trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion()))
.source(source);
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest,
ActionListener.wrap(
r -> listener.onResponse(true),
e -> {
logger.error(
new ParameterizedMessage("[{}][{}] failed to store trained model for inference",
trainedModelConfig.getModelId(),
trainedModelConfig.getModelVersion()),
e);
if (e instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS,
trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion())));
} else {
listener.onFailure(
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
RestStatus.INTERNAL_SERVER_ERROR,
e,
trainedModelConfig.getModelId()));
}
}));
} catch (IOException e) {
// not expected to happen but for the sake of completeness
listener.onFailure(new ElasticsearchParseException(
Messages.getMessage(Messages.INFERENCE_FAILED_TO_SERIALIZE_MODEL, trainedModelConfig.getModelId()),
e));
}
}
public void getTrainedModel(String modelId, long modelVersion, ActionListener<TrainedModelConfig> listener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
.idsQuery()
.addIds(TrainedModelConfig.documentId(modelId, modelVersion)));
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(queryBuilder)
// use sort to get the last
.addSort("_index", SortOrder.DESC)
.setSize(1)
.request();
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest,
ActionListener.wrap(
searchResponse -> {
if (searchResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion)));
return;
}
BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef();
parseInferenceDocLenientlyFromSource(source, modelId, modelVersion, listener);
},
listener::onFailure));
}
private void parseInferenceDocLenientlyFromSource(BytesReference source,
String modelId,
long modelVersion,
ActionListener<TrainedModelConfig> modelListener) {
try (InputStream stream = source.streamInput();
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build());
} catch (Exception e) {
logger.error(new ParameterizedMessage("[{}][{}] failed to parse model", modelId, modelVersion), e);
modelListener.onFailure(e);
}
}
}

View File

@ -0,0 +1,113 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.junit.Before;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class TrainedModelProviderIT extends MlSingleNodeTestCase {
private TrainedModelProvider trainedModelProvider;
@Before
public void createComponents() throws Exception {
trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
waitForMlTemplates();
}
public void testPutTrainedModelConfig() throws Exception {
String modelId = "test-put-trained-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
}
public void testPutTrainedModelConfigThatAlreadyExists() throws Exception {
String modelId = "test-put-trained-model-config-exists";
TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId, 0)));
}
public void testGetTrainedModelConfig() throws Exception {
String modelId = "test-get-trained-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(config));
}
public void testGetMissingTrainingModelConfig() throws Exception {
String modelId = "test-get-missing-trained-model-config";
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, 0)));
}
private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
.setDefinition(TreeTests.createRandom())
.setDescription("trained model config for test")
.setModelId(modelId)
.setModelType("binary_decision_tree")
.setModelVersion(modelVersion)
.build(Version.CURRENT);
}
@Override
public NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
}