Converts InferenceIngestIT into an `ESRestTestCase`. Closes #51201
This commit is contained in:
parent b38fdf9f94
commit 2a73e849d6
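The conversion replaces transport-client calls (`client().execute(...)`, `client().admin().cluster()...`) with plain HTTP calls through the low-level REST client: build a `Request`, send it with `client().performRequest(...)`, and assert against the raw JSON response body via `EntityUtils`. Below is a minimal sketch of that pattern, assuming a test that extends `ESRestTestCase`; the class name, the pipeline id `my_pipeline`, the index `my_index`, and the JSON bodies are placeholders for illustration, not values taken from this commit.

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.test.rest.ESRestTestCase;

import static org.hamcrest.CoreMatchers.containsString;

// Hypothetical example class, not part of the commit.
public class RestPatternSketchIT extends ESRestTestCase {

    public void testRestClientPattern() throws Exception {
        // Create a (placeholder) ingest pipeline via the low-level REST client.
        Request putPipeline = new Request("PUT", "_ingest/pipeline/my_pipeline");
        putPipeline.setJsonEntity("{\"processors\": []}");
        client().performRequest(putPipeline);

        // Index a document through the pipeline, then make it searchable.
        Request index = new Request("POST", "my_index/_doc?pipeline=my_pipeline");
        index.setJsonEntity("{\"col1\": \"female\"}");
        client().performRequest(index);
        client().performRequest(new Request("POST", "my_index/_refresh"));

        // Assertions are made against the raw JSON body instead of typed transport responses.
        Request search = new Request("GET", "my_index/_search?track_total_hits=true");
        search.setJsonEntity("{\"query\": {\"match_all\": {}}}");
        Response response = client().performRequest(search);
        assertThat(EntityUtils.toString(response.getEntity()), containsString("\"value\":1"));
    }
}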
@@ -5,141 +5,109 @@
 */
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.test.ExternalTestCluster;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;

public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
/**
 * This is an {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems,
 * specifically when ensuring that the accounting breaker has been reset.
 * The problem lies with `_simulate` and has nothing really to do with the ML code.
 */
public class InferenceIngestIT extends ESRestTestCase {

    private static final String BASIC_AUTH_VALUE_SUPER_USER =
        basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);

    @Before
    public void createBothModels() throws Exception {
        client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
        client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
        Request request = new Request("PUT", "_ml/inference/test_classification");
        request.setJsonEntity(CLASSIFICATION_CONFIG);
        client().performRequest(request);

        request = new Request("PUT", "_ml/inference/test_regression");
        request.setJsonEntity(REGRESSION_CONFIG);
        client().performRequest(request);
    }

    @Override
    protected Settings restClientSettings() {
        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
    }

    @After
    public void deleteBothModels() {
        client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
        client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
    public void cleanUpData() throws Exception {
        new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
        ESRestTestCase.waitForPendingTasks(adminClient());
        client().performRequest(new Request("DELETE", "_ml/inference/test_classification"));
        client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
    }

    public void testPipelineCreationAndDeletion() throws Exception {

        for (int i = 0; i < 10; i++) {
            assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
                new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get().isAcknowledged(), is(true));
            client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
            client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
            client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));

            client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
                .setSource(new HashMap<String, Object>(){{
                    put("col1", randomFrom("female", "male"));
                    put("col2", randomFrom("S", "M", "L", "XL"));
                    put("col3", randomFrom("true", "false", "none", "other"));
                    put("col4", randomIntBetween(0, 10));
                }})
                .setPipeline("simple_classification_pipeline")
                .get();

            assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
                is(true));

            assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
                new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get().isAcknowledged(), is(true));

            client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
                .setSource(new HashMap<String, Object>(){{
                    put("col1", randomFrom("female", "male"));
                    put("col2", randomFrom("S", "M", "L", "XL"));
                    put("col3", randomFrom("true", "false", "none", "other"));
                    put("col4", randomIntBetween(0, 10));
                }})
                .setPipeline("simple_regression_pipeline")
                .get();

            assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
                is(true));
            client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
            client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
            client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
        }

        assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
            new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
            XContentType.JSON).get().isAcknowledged(), is(true));

        assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
            new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
            XContentType.JSON).get().isAcknowledged(), is(true));
        client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
        client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));

        for (int i = 0; i < 10; i++) {
            client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
                .setSource(generateSourceDoc())
                .setPipeline("simple_classification_pipeline")
                .get();

            client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
                .setSource(generateSourceDoc())
                .setPipeline("simple_regression_pipeline")
                .get();
            client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
            client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
        }

        assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
            is(true));
        client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
        client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));

        assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
            is(true));
        client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));

        client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get();

        assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
            .source(new SearchSourceBuilder()
                .size(0)
                .trackTotalHits(true)
                .query(QueryBuilders.boolQuery()
        Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
            QueryBuilders.boolQuery()
                .filter(
                    QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))))).get().getHits().getTotalHits().value,
            equalTo(20L));
                    QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));

        assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
            .source(new SearchSourceBuilder()
                .size(0)
                .trackTotalHits(true)
                .query(QueryBuilders.boolQuery()
        searchResponse = client().performRequest(searchRequest("index_for_inference_test",
            QueryBuilders.boolQuery()
                .filter(
                    QueryBuilders.existsQuery("ml.inference.classification.predicted_value")))))
            .get()
            .getHits()
            .getTotalHits()
            .value,
            equalTo(20L));
                    QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));

        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
    }

    public void testSimulate() {
    public void testSimulate() throws IOException {
        String source = "{\n" +
            " \"pipeline\": {\n" +
            " \"processors\": [\n" +
@@ -181,15 +149,10 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
            " }}]\n" +
            "}";

        SimulatePipelineResponse response = client().admin().cluster()
            .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get();
        SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
        assertThat(baseResult.getIngestDocument().getFieldValue("ml.regression.predicted_value", Double.class), equalTo(1.0));
        assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.predicted_value", String.class),
            equalTo("second"));
        assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.result_class_prob", List.class).size(),
            equalTo(2));
        Response response = client().performRequest(simulateRequest(source));
        String responseString = EntityUtils.toString(response.getEntity());
        assertThat(responseString, containsString("\"predicted_value\":\"second\""));
        assertThat(responseString, containsString("\"predicted_value\":1.0"));

        String sourceWithMissingModel = "{\n" +
            " \"pipeline\": {\n" +
@@ -217,15 +180,13 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
            " }}]\n" +
            "}";

        response = client().admin().cluster()
            .prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get();
        response = client().performRequest(simulateRequest(sourceWithMissingModel));
        responseString = EntityUtils.toString(response.getEntity());

        assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(),
            containsString("Could not find trained model [test_classification_missing]"));
        assertThat(responseString, containsString("Could not find trained model [test_classification_missing]"));
    }

    public void testSimulateLangIdent() {
    public void testSimulateLangIdent() throws IOException {
        String source = "{\n" +
            " \"pipeline\": {\n" +
            " \"processors\": [\n" +
@@ -244,11 +205,43 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
            " }}]\n" +
            "}";

        SimulatePipelineResponse response = client().admin().cluster()
            .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get();
        SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
        assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en"));
        Response response = client().performRequest(simulateRequest(source));
        assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\""));
    }

    private static Request simulateRequest(String jsonEntity) {
        Request request = new Request("POST", "_ingest/pipeline/_simulate");
        request.setJsonEntity(jsonEntity);
        return request;
    }

    private static Request indexRequest(String index, String pipeline, Map<String, Object> doc) throws IOException {
        try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(doc)) {
            return indexRequest(index,
                pipeline,
                XContentHelper.convertToJson(BytesReference.bytes(xContentBuilder), false, XContentType.JSON));
        }
    }

    private static Request indexRequest(String index, String pipeline, String doc) {
        Request request = new Request("POST", index + "/_doc?pipeline=" + pipeline);
        request.setJsonEntity(doc);
        return request;
    }

    private static Request putPipeline(String pipelineId, String pipelineDefinition) {
        Request request = new Request("PUT", "_ingest/pipeline/" + pipelineId);
        request.setJsonEntity(pipelineDefinition);
        return request;
    }

    private static Request searchRequest(String index, QueryBuilder queryBuilder) throws IOException {
        BytesReference reference = XContentHelper.toXContent(queryBuilder, XContentType.JSON, false);
        String queryJson = XContentHelper.convertToJson(reference, false, XContentType.JSON);
        String json = "{\"query\": " + queryJson + "}";
        Request request = new Request("GET", index + "/_search?track_total_hits=true");
        request.setJsonEntity(json);
        return request;
    }

    private Map<String, Object> generateSourceDoc() {
@@ -380,16 +373,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
        "}";

    private static final String REGRESSION_CONFIG = "{" +
        " \"model_id\": \"test_regression\",\n" +
        " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
        " \"description\": \"test model for regression\",\n" +
        " \"version\": \"7.6.0\",\n" +
        " \"definition\": " + REGRESSION_DEFINITION + ","+
        " \"license_level\": \"platinum\",\n" +
        " \"created_by\": \"ml_test\",\n" +
        " \"estimated_heap_memory_usage_bytes\": 0," +
        " \"estimated_operations\": 0," +
        " \"created_time\": 0" +
        " \"definition\": " + REGRESSION_DEFINITION +
        "}";

    private static final String CLASSIFICATION_DEFINITION = "{" +
@@ -512,24 +498,6 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
        " }\n" +
        "}";

    private TrainedModelConfig buildClassificationModel() throws IOException {
        try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
            new BytesArray(CLASSIFICATION_CONFIG),
            XContentType.JSON)) {
            return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
        }
    }

    private TrainedModelConfig buildRegressionModel() throws IOException {
        try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
            new BytesArray(REGRESSION_CONFIG),
            XContentType.JSON)) {
            return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
        }
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
@@ -537,16 +505,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {

    private static final String CLASSIFICATION_CONFIG = "" +
        "{\n" +
        " \"model_id\": \"test_classification\",\n" +
        " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
        " \"description\": \"test model for classification\",\n" +
        " \"version\": \"7.6.0\",\n" +
        " \"definition\": " + CLASSIFICATION_DEFINITION + ","+
        " \"license_level\": \"platinum\",\n" +
        " \"created_by\": \"es_test\",\n" +
        " \"estimated_heap_memory_usage_bytes\": 0," +
        " \"estimated_operations\": 0," +
        " \"created_time\": 0\n" +
        " \"definition\": " + CLASSIFICATION_DEFINITION +
        "}";

    private static final String CLASSIFICATION_PIPELINE = "{" +