[ML][Inference] fixing ingest IT tests (#51267) (#51311)

Converts InferenceIngestIT into a `ESRestTestCase`.

closes #51201
This commit is contained in:
Benjamin Trent 2020-01-22 09:50:17 -05:00 committed by GitHub
parent b38fdf9f94
commit 2a73e849d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 111 additions and 150 deletions

View File

@ -5,141 +5,109 @@
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.test.ExternalTestCluster;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
/**
* This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems
* Specifically, ensuring the accounting breaker has been reset.
* It has to do with `_simulate` not anything really to do with the ML code
*/
public class InferenceIngestIT extends ESRestTestCase {
private static final String BASIC_AUTH_VALUE_SUPER_USER =
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
@Before
public void createBothModels() throws Exception {
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
Request request = new Request("PUT", "_ml/inference/test_classification");
request.setJsonEntity(CLASSIFICATION_CONFIG);
client().performRequest(request);
request = new Request("PUT", "_ml/inference/test_regression");
request.setJsonEntity(REGRESSION_CONFIG);
client().performRequest(request);
}
@Override
protected Settings restClientSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
}
@After
public void deleteBothModels() {
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
public void cleanUpData() throws Exception {
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
ESRestTestCase.waitForPendingTasks(adminClient());
client().performRequest(new Request("DELETE", "_ml/inference/test_classification"));
client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
}
public void testPipelineCreationAndDeletion() throws Exception {
for (int i = 0; i < 10; i++) {
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(new HashMap<String, Object>(){{
put("col1", randomFrom("female", "male"));
put("col2", randomFrom("S", "M", "L", "XL"));
put("col3", randomFrom("true", "false", "none", "other"));
put("col4", randomIntBetween(0, 10));
}})
.setPipeline("simple_classification_pipeline")
.get();
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
is(true));
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(new HashMap<String, Object>(){{
put("col1", randomFrom("female", "male"));
put("col2", randomFrom("S", "M", "L", "XL"));
put("col3", randomFrom("true", "false", "none", "other"));
put("col4", randomIntBetween(0, 10));
}})
.setPipeline("simple_regression_pipeline")
.get();
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
is(true));
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
}
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
for (int i = 0; i < 10; i++) {
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(generateSourceDoc())
.setPipeline("simple_classification_pipeline")
.get();
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(generateSourceDoc())
.setPipeline("simple_regression_pipeline")
.get();
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
}
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
is(true));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
is(true));
client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));
client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get();
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
.source(new SearchSourceBuilder()
.size(0)
.trackTotalHits(true)
.query(QueryBuilders.boolQuery()
.filter(
QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))))).get().getHits().getTotalHits().value,
equalTo(20L));
Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
QueryBuilders.boolQuery()
.filter(
QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
.source(new SearchSourceBuilder()
.size(0)
.trackTotalHits(true)
.query(QueryBuilders.boolQuery()
.filter(
QueryBuilders.existsQuery("ml.inference.classification.predicted_value")))))
.get()
.getHits()
.getTotalHits()
.value,
equalTo(20L));
searchResponse = client().performRequest(searchRequest("index_for_inference_test",
QueryBuilders.boolQuery()
.filter(
QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
}
public void testSimulate() {
public void testSimulate() throws IOException {
String source = "{\n" +
" \"pipeline\": {\n" +
" \"processors\": [\n" +
@ -181,15 +149,10 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }}]\n" +
"}";
SimulatePipelineResponse response = client().admin().cluster()
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get();
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
assertThat(baseResult.getIngestDocument().getFieldValue("ml.regression.predicted_value", Double.class), equalTo(1.0));
assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.predicted_value", String.class),
equalTo("second"));
assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.result_class_prob", List.class).size(),
equalTo(2));
Response response = client().performRequest(simulateRequest(source));
String responseString = EntityUtils.toString(response.getEntity());
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
assertThat(responseString, containsString("\"predicted_value\":1.0"));
String sourceWithMissingModel = "{\n" +
" \"pipeline\": {\n" +
@ -217,15 +180,13 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }}]\n" +
"}";
response = client().admin().cluster()
.prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get();
response = client().performRequest(simulateRequest(sourceWithMissingModel));
responseString = EntityUtils.toString(response.getEntity());
assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(),
containsString("Could not find trained model [test_classification_missing]"));
assertThat(responseString, containsString("Could not find trained model [test_classification_missing]"));
}
public void testSimulateLangIdent() {
public void testSimulateLangIdent() throws IOException {
String source = "{\n" +
" \"pipeline\": {\n" +
" \"processors\": [\n" +
@ -244,11 +205,43 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }}]\n" +
"}";
SimulatePipelineResponse response = client().admin().cluster()
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get();
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en"));
Response response = client().performRequest(simulateRequest(source));
assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\""));
}
private static Request simulateRequest(String jsonEntity) {
Request request = new Request("POST", "_ingest/pipeline/_simulate");
request.setJsonEntity(jsonEntity);
return request;
}
private static Request indexRequest(String index, String pipeline, Map<String, Object> doc) throws IOException {
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(doc)) {
return indexRequest(index,
pipeline,
XContentHelper.convertToJson(BytesReference.bytes(xContentBuilder), false, XContentType.JSON));
}
}
private static Request indexRequest(String index, String pipeline, String doc) {
Request request = new Request("POST", index + "/_doc?pipeline=" + pipeline);
request.setJsonEntity(doc);
return request;
}
private static Request putPipeline(String pipelineId, String pipelineDefinition) {
Request request = new Request("PUT", "_ingest/pipeline/" + pipelineId);
request.setJsonEntity(pipelineDefinition);
return request;
}
private static Request searchRequest(String index, QueryBuilder queryBuilder) throws IOException {
BytesReference reference = XContentHelper.toXContent(queryBuilder, XContentType.JSON, false);
String queryJson = XContentHelper.convertToJson(reference, false, XContentType.JSON);
String json = "{\"query\": " + queryJson + "}";
Request request = new Request("GET", index + "/_search?track_total_hits=true");
request.setJsonEntity(json);
return request;
}
private Map<String, Object> generateSourceDoc() {
@ -380,16 +373,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
"}";
private static final String REGRESSION_CONFIG = "{" +
" \"model_id\": \"test_regression\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"definition\": " + REGRESSION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0" +
" \"definition\": " + REGRESSION_DEFINITION +
"}";
private static final String CLASSIFICATION_DEFINITION = "{" +
@ -512,24 +498,6 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }\n" +
"}";
private TrainedModelConfig buildClassificationModel() throws IOException {
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(CLASSIFICATION_CONFIG),
XContentType.JSON)) {
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
}
}
private TrainedModelConfig buildRegressionModel() throws IOException {
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(REGRESSION_CONFIG),
XContentType.JSON)) {
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
}
}
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
@ -537,16 +505,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
private static final String CLASSIFICATION_CONFIG = "" +
"{\n" +
" \"model_id\": \"test_classification\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"definition\": " + CLASSIFICATION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"es_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +
" \"definition\": " + CLASSIFICATION_DEFINITION +
"}";
private static final String CLASSIFICATION_PIPELINE = "{" +