* [ML] prefer secondary auth headers on evaluate (#59167) We should prefer the secondary auth headers when evaluating a data frame
This commit is contained in:
parent
24c6a30e2b
commit
e343e066fc
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.client.Request;
|
||||
import org.elasticsearch.client.RequestOptions;
|
||||
import org.elasticsearch.client.ResponseException;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
||||
|
||||
public class ClassificationEvaluationWithSecurityIT extends ESRestTestCase {
|
||||
private static final String BASIC_AUTH_VALUE_SUPER_USER =
|
||||
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||
|
||||
@Override
|
||||
protected Settings restClientSettings() {
|
||||
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
|
||||
}
|
||||
|
||||
private static void setupDataAccessRole(String index) throws IOException {
|
||||
Request request = new Request("PUT", "/_security/role/test_data_access");
|
||||
request.setJsonEntity("{"
|
||||
+ " \"indices\" : ["
|
||||
+ " { \"names\": [\"" + index + "\"], \"privileges\": [\"read\"] }"
|
||||
+ " ]"
|
||||
+ "}");
|
||||
client().performRequest(request);
|
||||
}
|
||||
|
||||
private void setupUser(String user, List<String> roles) throws IOException {
|
||||
String password = new String(SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING.getChars());
|
||||
|
||||
Request request = new Request("PUT", "/_security/user/" + user);
|
||||
request.setJsonEntity("{"
|
||||
+ " \"password\" : \"" + password + "\","
|
||||
+ " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]"
|
||||
+ "}");
|
||||
client().performRequest(request);
|
||||
}
|
||||
|
||||
public void testEvaluate_withSecurity() throws Exception {
|
||||
String index = "test_data";
|
||||
Request createDoc = new Request("POST", index + "/_doc");
|
||||
createDoc.setJsonEntity(
|
||||
"{\n" +
|
||||
" \"is_outlier\": 0.0,\n" +
|
||||
" \"ml.outlier_score\": 1.0\n" +
|
||||
"}"
|
||||
);
|
||||
client().performRequest(createDoc);
|
||||
Request refreshRequest = new Request("POST", index + "/_refresh");
|
||||
client().performRequest(refreshRequest);
|
||||
setupDataAccessRole(index);
|
||||
setupUser("ml_admin", Collections.singletonList("machine_learning_admin"));
|
||||
setupUser("ml_admin_plus_data", Arrays.asList("machine_learning_admin", "test_data_access"));
|
||||
String mlAdmin = basicAuthHeaderValue("ml_admin", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||
String mlAdminPlusData = basicAuthHeaderValue("ml_admin_plus_data", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||
Request evaluateRequest = buildRegressionEval(index, mlAdmin, mlAdminPlusData);
|
||||
client().performRequest(evaluateRequest);
|
||||
|
||||
Request failingRequest = buildRegressionEval(index, mlAdminPlusData, mlAdmin);
|
||||
expectThrows(ResponseException.class, () -> client().performRequest(failingRequest));
|
||||
}
|
||||
|
||||
private static Request buildRegressionEval(String index, String primaryHeader, String secondaryHeader) {
|
||||
Request evaluateRequest = new Request("POST", "_ml/data_frame/_evaluate");
|
||||
evaluateRequest.setJsonEntity(
|
||||
"{\n" +
|
||||
" \"index\": \"" + index + "\",\n" +
|
||||
" \"evaluation\": {\n" +
|
||||
" \"regression\": {\n" +
|
||||
" \"actual_field\": \"is_outlier\",\n" +
|
||||
" \"predicted_field\": \"ml.outlier_score\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}\n"
|
||||
);
|
||||
RequestOptions.Builder options = evaluateRequest.getOptions().toBuilder();
|
||||
options.addHeader("Authorization", primaryHeader);
|
||||
options.addHeader("es-secondary-authorization", secondaryHeader);
|
||||
evaluateRequest.setOptions(options);
|
||||
return evaluateRequest;
|
||||
}
|
||||
}
|
|
@ -13,19 +13,23 @@ import org.elasticsearch.action.support.HandledTransportAction;
|
|||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.XPackSettings;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.security.SecurityContext;
|
||||
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
|
||||
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
|
||||
|
||||
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
|
||||
EvaluateDataFrameAction.Response> {
|
||||
|
@ -33,9 +37,11 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
private final ThreadPool threadPool;
|
||||
private final Client client;
|
||||
private final AtomicReference<Integer> maxBuckets = new AtomicReference<>();
|
||||
private final SecurityContext securityContext;
|
||||
|
||||
@Inject
|
||||
public TransportEvaluateDataFrameAction(TransportService transportService,
|
||||
Settings settings,
|
||||
ActionFilters actionFilters,
|
||||
ThreadPool threadPool,
|
||||
Client client,
|
||||
|
@ -43,6 +49,8 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
|
||||
this.threadPool = threadPool;
|
||||
this.client = client;
|
||||
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
|
||||
new SecurityContext(settings, threadPool.getThreadContext()) : null;
|
||||
this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings()));
|
||||
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets);
|
||||
}
|
||||
|
@ -66,7 +74,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
|
||||
// Create an immutable collection of parameters to be used by evaluation metrics.
|
||||
EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get());
|
||||
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request);
|
||||
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request, securityContext);
|
||||
evaluationExecutor.execute(resultsListener);
|
||||
}
|
||||
|
||||
|
@ -89,13 +97,19 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
private final EvaluationParameters parameters;
|
||||
private final EvaluateDataFrameAction.Request request;
|
||||
private final Evaluation evaluation;
|
||||
private final SecurityContext securityContext;
|
||||
|
||||
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluationParameters parameters, EvaluateDataFrameAction.Request request) {
|
||||
EvaluationExecutor(ThreadPool threadPool,
|
||||
Client client,
|
||||
EvaluationParameters parameters,
|
||||
EvaluateDataFrameAction.Request request,
|
||||
SecurityContext securityContext) {
|
||||
super(threadPool.generic(), unused -> true, unused -> true);
|
||||
this.client = client;
|
||||
this.parameters = parameters;
|
||||
this.request = request;
|
||||
this.evaluation = request.getEvaluation();
|
||||
this.securityContext = securityContext;
|
||||
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
|
||||
add(nextTask());
|
||||
}
|
||||
|
@ -104,18 +118,19 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
return listener -> {
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery());
|
||||
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
|
||||
client.execute(
|
||||
SearchAction.INSTANCE,
|
||||
searchRequest,
|
||||
ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
evaluation.process(searchResponse);
|
||||
if (evaluation.hasAllResults() == false) {
|
||||
add(nextTask());
|
||||
}
|
||||
listener.onResponse(null);
|
||||
},
|
||||
listener::onFailure));
|
||||
useSecondaryAuthIfAvailable(securityContext,
|
||||
() -> client.execute(
|
||||
SearchAction.INSTANCE,
|
||||
searchRequest,
|
||||
ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
evaluation.process(searchResponse);
|
||||
if (evaluation.hasAllResults() == false) {
|
||||
add(nextTask());
|
||||
}
|
||||
listener.onResponse(null);
|
||||
},
|
||||
listener::onFailure)));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue