[ML][Inference] Add support for models shipped as resources (#50680) (#50700)

This adds support for models that are shipped as resources in the ML plugin. The first such model is the `lang_ident` model.
Benjamin Trent 2020-01-07 09:21:59 -05:00 committed by GitHub
parent 98ca9500e8
commit 060e0a6277
12 changed files with 290 additions and 97 deletions
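
To make the mechanism concrete, here is a minimal, self-contained sketch of the pattern this commit relies on: the model's JSON configuration is bundled on the plugin classpath and read on demand instead of being fetched from the inference index. The class name, resource path, and main method below are illustrative assumptions only; the actual logic added in this commit lives in TrainedModelProvider.loadModelFromResource.

// Minimal sketch (not the actual TrainedModelProvider code): a model shipped as a
// resource is just a JSON document on the plugin classpath that is read on demand.
// The path and class name here are hypothetical.
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

public class ResourceModelSketch {

    // Hypothetical classpath location mirroring the layout this commit introduces.
    private static final String RESOURCE_PATH = "/org/example/ml/models/";
    private static final String RESOURCE_EXT = ".json";

    /** Reads the bundled JSON config for a model id, or throws if it is not packaged. */
    static String readModelJson(String modelId) throws IOException {
        try (InputStream in = ResourceModelSketch.class
                .getResourceAsStream(RESOURCE_PATH + modelId + RESOURCE_EXT)) {
            if (in == null) {
                throw new IOException("model [" + modelId + "] is not shipped as a resource");
            }
            // The real code streams this into an XContent parser and builds a TrainedModelConfig.
            return new String(in.readAllBytes(), StandardCharsets.UTF_8);
        }
    }

    public static void main(String[] args) throws IOException {
        // Usage: print the bundled definition for a hypothetical resource-backed model id.
        System.out.println(readModelJson("lang_ident_model_1"));
    }
}
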


@@ -2198,8 +2198,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
GetTrainedModelsResponse getTrainedModelsResponse = execute(
GetTrainedModelsRequest.getAllTrainedModelConfigsRequest(),
machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync);
-assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels));
-assertThat(getTrainedModelsResponse.getCount(), equalTo(5L));
+assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels + 1));
+assertThat(getTrainedModelsResponse.getCount(), equalTo(5L + 1));
}
{
GetTrainedModelsResponse getTrainedModelsResponse = execute(
@@ -2222,7 +2222,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
public void testGetTrainedModelsStats() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
-String modelIdPrefix = "get-trained-model-stats-";
+String modelIdPrefix = "a-get-trained-model-stats-";
int numberOfModels = 5;
for (int i = 0; i < numberOfModels; ++i) {
String modelId = modelIdPrefix + i;
@@ -2254,8 +2254,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(),
machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
-assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels));
-assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
+assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels + 1));
+assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L + 1));
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1));
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0));
}


@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -98,6 +99,10 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
pipelineCount = in.readVInt();
}
+public String getModelId() {
+return modelId;
+}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@@ -186,6 +191,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
0 :
ingestStats.getPipelineStats().size()));
});
+trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId));
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
}
}


@@ -409,6 +409,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
+public Builder clearDefinition() {
+this.definition = null;
+return this;
+}
private Builder setLazyDefinition(TrainedModelDefinition.Builder parsedTrainedModel) {
if (parsedTrainedModel == null) {
return this;


@@ -87,10 +87,12 @@ public final class Messages {
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
+public static final String INFERENCE_CANNOT_DELETE_MODEL =
+"Unable to delete model [{0}]";
public static final String MODEL_DEFINITION_TRUNCATED =
"Model definition truncated. Unable to deserialize trained model definition [{0}]";
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
-public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
+public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";


@@ -233,6 +233,32 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
containsString("Could not find trained model [test_classification_missing]"));
}
+public void testSimulateLangIdent() {
+String source = "{\n" +
+" \"pipeline\": {\n" +
+" \"processors\": [\n" +
+" {\n" +
+" \"inference\": {\n" +
+" \"inference_config\": {\"classification\":{}},\n" +
+" \"model_id\": \"lang_ident_model_1\",\n" +
+" \"field_mappings\": {}\n" +
+" }\n" +
+" }\n" +
+" ]\n" +
+" },\n" +
+" \"docs\": [\n" +
+" {\"_source\": {\n" +
+" \"text\": \"this is some plain text.\"\n" +
+" }}]\n" +
+"}";
+SimulatePipelineResponse response = client().admin().cluster()
+.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
+XContentType.JSON).get();
+SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
+assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en"));
+}
private Map<String, Object> generateSourceDoc() {
return new HashMap<String, Object>(){{
put("col1", randomFrom("female", "male"));


@@ -60,8 +60,8 @@ public class TrainedModelIT extends ESRestTestCase {
}
public void testGetTrainedModels() throws IOException {
-String modelId = "test_regression_model";
-String modelId2 = "test_regression_model-2";
+String modelId = "a_test_regression_model";
+String modelId2 = "a_test_regression_model-2";
Request model1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
model1.setJsonEntity(buildRegressionModel(modelId));
@@ -84,36 +84,36 @@ public class TrainedModelIT extends ESRestTestCase {
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
String response = EntityUtils.toString(getModel.getEntity());
-assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
assertThat(response, containsString("\"count\":1"));
getModel = client().performRequest(new Request("GET",
-MachineLearning.BASE_PATH + "inference/test_regression*"));
+MachineLearning.BASE_PATH + "inference/a_test_regression*"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
-assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
-assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
assertThat(response, not(containsString("\"definition\"")));
assertThat(response, containsString("\"count\":2"));
getModel = client().performRequest(new Request("GET",
-MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true"));
+MachineLearning.BASE_PATH + "inference/a_test_regression_model?human=true&include_model_definition=true"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
-assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
assertThat(response, containsString("\"definition\""));
assertThat(response, containsString("\"count\":1"));
getModel = client().performRequest(new Request("GET",
-MachineLearning.BASE_PATH + "inference/test_regression_model?decompress_definition=false&include_model_definition=true"));
+MachineLearning.BASE_PATH + "inference/a_test_regression_model?decompress_definition=false&include_model_definition=true"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
-assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
assertThat(response, containsString("\"compressed_definition\""));
assertThat(response, not(containsString("\"definition\"")));
@@ -121,17 +121,17 @@ public class TrainedModelIT extends ESRestTestCase {
ResponseException responseException = expectThrows(ResponseException.class, () ->
client().performRequest(new Request("GET",
-MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true")));
+MachineLearning.BASE_PATH + "inference/a_test_regression*?human=true&include_model_definition=true")));
assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
-containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED));
+containsString(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED));
getModel = client().performRequest(new Request("GET",
-MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2"));
+MachineLearning.BASE_PATH + "inference/a_test_regression_model,a_test_regression_model-2"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
-assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
-assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
assertThat(response, containsString("\"count\":2"));
getModel = client().performRequest(new Request("GET",
@@ -149,17 +149,17 @@ public class TrainedModelIT extends ESRestTestCase {
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
-assertThat(response, containsString("\"count\":2"));
-assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
-assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\"")));
+assertThat(response, containsString("\"count\":3"));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
+assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model-2\"")));
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
-assertThat(response, containsString("\"count\":2"));
-assertThat(response, not(containsString("\"model_id\":\"test_regression_model\"")));
-assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
+assertThat(response, containsString("\"count\":3"));
+assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model\"")));
+assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
}
public void testDeleteTrainedModels() throws IOException {


@@ -50,7 +50,7 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
listener.onFailure(
-ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)
+ExceptionsHelper.badRequestException(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED)
);
return;
}


@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.persistence;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
@@ -31,6 +32,7 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
@@ -39,6 +41,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.IndexNotFoundException;
@@ -65,8 +68,10 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.io.InputStream;
+import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
@@ -79,6 +84,10 @@ import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FA
public class TrainedModelProvider {
+public static final Set<String> MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1");
+private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/";
+private static final String MODEL_RESOURCE_FILE_EXT = ".json";
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
private final Client client;
private final NamedXContentRegistry xContentRegistry;
@@ -92,6 +101,12 @@ public class TrainedModelProvider {
public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
ActionListener<Boolean> listener) {
+if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) {
+listener.onFailure(new ResourceAlreadyExistsException(
+Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
+return;
+}
try {
trainedModelConfig.ensureParsedDefinition(xContentRegistry);
} catch (IOException ex) {
@@ -185,6 +200,16 @@ public class TrainedModelProvider {
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
+if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
+try {
+listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
+return;
+} catch (ElasticsearchException ex) {
+listener.onFailure(ex);
+return;
+}
+}
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
.idsQuery()
.addIds(modelId));
@@ -268,11 +293,29 @@ public class TrainedModelProvider {
.addSort("_index", SortOrder.DESC)
.setQuery(queryBuilder)
.request();
+List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
+Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
+Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
+for(String modelId : modelsAsResource) {
+try {
+configs.add(loadModelFromResource(modelId, true));
+} catch (ElasticsearchException ex) {
+listener.onFailure(ex);
+return;
+}
+}
+if (modelsInIndex.isEmpty()) {
+configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
+listener.onResponse(configs);
+return;
+}
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
searchResponse -> {
-Set<String> observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f);
-List<TrainedModelConfig> configs = new ArrayList<>(searchResponse.getHits().getHits().length);
+Set<String> observedIds = new HashSet<>(
+searchResponse.getHits().getHits().length + modelsAsResource.size(),
+1.0f);
+observedIds.addAll(modelsAsResource);
for(SearchHit searchHit : searchResponse.getHits().getHits()) {
try {
if (observedIds.contains(searchHit.getId()) == false) {
@@ -295,6 +338,8 @@ public class TrainedModelProvider {
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
return;
}
+// Ensure sorted even with the injection of locally resourced models
+configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
listener.onResponse(configs);
},
listener::onFailure
@@ -304,6 +349,10 @@ public class TrainedModelProvider {
}
public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener) {
+if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
+listener.onFailure(ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, modelId)));
+return;
+}
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
request.indices(InferenceIndexConstants.INDEX_PATTERN);
@@ -360,8 +409,8 @@ public class TrainedModelProvider {
searchRequest,
ActionListener.<SearchResponse>wrap(
response -> {
-Set<String> foundResourceIds = new LinkedHashSet<>();
-long totalHitCount = response.getHits().getTotalHits().value;
+Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
+long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSource = hit.getSourceAsMap();
if (docSource == null) {
@@ -386,6 +435,37 @@ public class TrainedModelProvider {
}
+TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
+URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
+if (resource == null) {
+logger.error("[{}] presumed stored as a resource but not found", modelId);
+throw new ResourceNotFoundException(
+Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId));
+}
+try {
+BytesReference bytes = Streams.readFully(getClass()
+.getResourceAsStream(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT));
+try (XContentParser parser =
+XContentHelper.createParser(xContentRegistry,
+LoggingDeprecationHandler.INSTANCE,
+bytes,
+XContentType.JSON)) {
+TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true);
+if (nullOutDefinition) {
+builder.clearDefinition();
+}
+return builder.build();
+} catch (IOException ioEx) {
+logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
+throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
+}
+} catch (IOException ex) {
+String msg = new ParameterizedMessage("[{}] failed to read model as resource", modelId).getFormattedMessage();
+logger.error(msg, ex);
+throw ExceptionsHelper.serverError(msg, ex);
+}
+}
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
@@ -414,6 +494,29 @@ public class TrainedModelProvider {
return boolQuery;
}
+private Set<String> matchedResourceIds(String[] tokens) {
+if (Strings.isAllOrWildcard(tokens)) {
+return new HashSet<>(MODELS_STORED_AS_RESOURCE);
+}
+Set<String> matchedModels = new HashSet<>();
+for (String token : tokens) {
+if (Regex.isSimpleMatchPattern(token)) {
+for (String modelId : MODELS_STORED_AS_RESOURCE) {
+if(Regex.simpleMatch(token, modelId)) {
+matchedModels.add(modelId);
+}
+}
+} else {
+if (MODELS_STORED_AS_RESOURCE.contains(token)) {
+matchedModels.add(token);
+}
+}
+}
+return matchedModels;
+}
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
String resourceId,
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {


@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.inference.persistence;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Mockito.mock;
+public class TrainedModelProviderTests extends ESTestCase {
+public void testDeleteModelStoredAsResource() {
+TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+PlainActionFuture<Boolean> future = new PlainActionFuture<>();
+// Should be OK as we don't make any client calls
+trainedModelProvider.deleteTrainedModel("lang_ident_model_1", future);
+ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet);
+assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, "lang_ident_model_1")));
+}
+public void testPutModelThatExistsAsResource() {
+TrainedModelConfig config = TrainedModelConfigTests.createTestInstance("lang_ident_model_1").build();
+TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+PlainActionFuture<Boolean> future = new PlainActionFuture<>();
+trainedModelProvider.storeTrainedModel(config, future);
+ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet);
+assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, "lang_ident_model_1")));
+}
+public void testGetModelThatExistsAsResource() throws Exception {
+TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
+PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
+trainedModelProvider.getTrainedModel(modelId, true, future);
+TrainedModelConfig configWithDefinition = future.actionGet();
+assertThat(configWithDefinition.getModelId(), equalTo(modelId));
+assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
+PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
+trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition);
+TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
+assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));
+assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
+}
+}
+public void testGetModelThatExistsAsResourceButIsMissing() {
+TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+() -> trainedModelProvider.loadModelFromResource("missing_model", randomBoolean()));
+assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, "missing_model")));
+}
+@Override
+protected NamedXContentRegistry xContentRegistry() {
+return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+}
+}


@@ -5,10 +5,9 @@
*/
package org.elasticsearch.xpack.ml.inference.trainedmodels.langident;
-import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.Client;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -16,22 +15,26 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
-import java.io.IOException;
-import java.nio.file.Files;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
+import static org.mockito.Mockito.mock;
public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
public void testLangInference() throws Exception {
-TrainedModelConfig config = getLangIdentModel();
+TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
+// Should be OK as we don't make any client calls
+trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future);
+TrainedModelConfig config = future.actionGet();
+config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
ClassificationConfig classificationConfig = new ClassificationConfig(1);
@@ -53,19 +56,6 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
}
}
-private TrainedModelConfig getLangIdentModel() throws IOException {
-String path = "/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json";
-try(XContentParser parser =
-XContentType.JSON.xContent().createParser(
-NamedXContentRegistry.EMPTY,
-DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
-Files.newInputStream(getDataPath(path)))) {
-TrainedModelConfig config = TrainedModelConfig.fromXContent(parser, true).build();
-config.ensureParsedDefinition(xContentRegistry());
-return config;
-}
-}
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());


@@ -1,18 +1,3 @@
----
-"Test get-all given no trained models exist":
-- do:
-ml.get_trained_models:
-model_id: "_all"
-- match: { count: 0 }
-- match: { trained_model_configs: [] }
-- do:
-ml.get_trained_models:
-model_id: "*"
-- match: { count: 0 }
-- match: { trained_model_configs: [] }
---
"Test get given missing trained model":
@@ -111,3 +96,11 @@
catch: conflict
ml.delete_trained_model:
model_id: "used-regression-model"
+---
+"Test get pre-packaged trained models":
+- do:
+ml.get_trained_models:
+model_id: "_all"
+allow_no_match: false
+- match: { count: 1 }
+- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }


@@ -5,17 +5,15 @@ setup:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
-id: trained_model_config-unused-regression-model1-0
+id: trained_model_config-a-unused-regression-model1-0
index: .ml-inference-000001
body: >
{
-"model_id": "unused-regression-model1",
+"model_id": "a-unused-regression-model1",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
-"model_version": 0,
-"model_type": "local",
"doc_type": "trained_model_config"
}
@@ -23,34 +21,30 @@ setup:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
-id: trained_model_config-unused-regression-model-0
+id: trained_model_config-a-unused-regression-model-0
index: .ml-inference-000001
body: >
{
-"model_id": "unused-regression-model",
+"model_id": "a-unused-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
-"model_version": 0,
-"model_type": "local",
"doc_type": "trained_model_config"
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
-id: trained_model_config-used-regression-model-0
+id: trained_model_config-a-used-regression-model-0
index: .ml-inference-000001
body: >
{
-"model_id": "used-regression-model",
+"model_id": "a-used-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
-"model_version": 0,
-"model_type": "local",
"doc_type": "trained_model_config"
}
@@ -69,7 +63,7 @@ setup:
"processors": [
{
"inference" : {
-"model_id" : "used-regression-model",
+"model_id" : "a-used-regression-model",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
@@ -87,7 +81,7 @@ setup:
"processors": [
{
"inference" : {
-"model_id" : "used-regression-model",
+"model_id" : "a-used-regression-model",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
@@ -125,18 +119,18 @@ setup:
- do:
ml.get_trained_models_stats:
-model_id: "unused-regression-model"
+model_id: "a-unused-regression-model"
- match: { count: 1 }
- do:
ml.get_trained_models_stats:
model_id: "_all"
-- match: { count: 3 }
-- match: { trained_model_stats.0.model_id: unused-regression-model }
+- match: { count: 4 }
+- match: { trained_model_stats.0.model_id: a-unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
-- match: { trained_model_stats.1.model_id: unused-regression-model1 }
+- match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
- match: { trained_model_stats.1.pipeline_count: 0 }
- is_false: trained_model_stats.1.ingest
- match: { trained_model_stats.2.pipeline_count: 2 }
@@ -145,11 +139,11 @@ setup:
- do:
ml.get_trained_models_stats:
model_id: "*"
-- match: { count: 3 }
-- match: { trained_model_stats.0.model_id: unused-regression-model }
+- match: { count: 4 }
+- match: { trained_model_stats.0.model_id: a-unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
-- match: { trained_model_stats.1.model_id: unused-regression-model1 }
+- match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
- match: { trained_model_stats.1.pipeline_count: 0 }
- is_false: trained_model_stats.1.ingest
- match: { trained_model_stats.2.pipeline_count: 2 }
@@ -157,40 +151,40 @@ setup:
- do:
ml.get_trained_models_stats:
-model_id: "unused-regression-model*"
+model_id: "a-unused-regression-model*"
- match: { count: 2 }
-- match: { trained_model_stats.0.model_id: unused-regression-model }
+- match: { trained_model_stats.0.model_id: a-unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
-- match: { trained_model_stats.1.model_id: unused-regression-model1 }
+- match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
- match: { trained_model_stats.1.pipeline_count: 0 }
- is_false: trained_model_stats.1.ingest
- do:
ml.get_trained_models_stats:
-model_id: "unused-regression-model*"
+model_id: "a-unused-regression-model*"
size: 1
- match: { count: 2 }
-- match: { trained_model_stats.0.model_id: unused-regression-model }
+- match: { trained_model_stats.0.model_id: a-unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
- do:
ml.get_trained_models_stats:
-model_id: "unused-regression-model*"
+model_id: "a-unused-regression-model*"
from: 1
size: 1
- match: { count: 2 }
-- match: { trained_model_stats.0.model_id: unused-regression-model1 }
+- match: { trained_model_stats.0.model_id: a-unused-regression-model1 }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
- do:
ml.get_trained_models_stats:
-model_id: "used-regression-model"
+model_id: "a-used-regression-model"
- match: { count: 1 }
-- match: { trained_model_stats.0.model_id: used-regression-model }
+- match: { trained_model_stats.0.model_id: a-used-regression-model }
- match: { trained_model_stats.0.pipeline_count: 2 }
- match:
trained_model_stats.0.ingest.total: