[7.x] [ML] prefer secondary auth headers on evaluate (#59167) (#59183)

* [ML] prefer secondary auth headers on evaluate (#59167)

We should prefer the secondary auth headers when evaluating a data frame
This commit is contained in:
Benjamin Trent 2020-07-07 15:34:47 -04:00 committed by GitHub
parent 24c6a30e2b
commit e343e066fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 14 deletions

View File

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
public class ClassificationEvaluationWithSecurityIT extends ESRestTestCase {
private static final String BASIC_AUTH_VALUE_SUPER_USER =
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
@Override
protected Settings restClientSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
}
private static void setupDataAccessRole(String index) throws IOException {
Request request = new Request("PUT", "/_security/role/test_data_access");
request.setJsonEntity("{"
+ " \"indices\" : ["
+ " { \"names\": [\"" + index + "\"], \"privileges\": [\"read\"] }"
+ " ]"
+ "}");
client().performRequest(request);
}
private void setupUser(String user, List<String> roles) throws IOException {
String password = new String(SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING.getChars());
Request request = new Request("PUT", "/_security/user/" + user);
request.setJsonEntity("{"
+ " \"password\" : \"" + password + "\","
+ " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]"
+ "}");
client().performRequest(request);
}
public void testEvaluate_withSecurity() throws Exception {
String index = "test_data";
Request createDoc = new Request("POST", index + "/_doc");
createDoc.setJsonEntity(
"{\n" +
" \"is_outlier\": 0.0,\n" +
" \"ml.outlier_score\": 1.0\n" +
"}"
);
client().performRequest(createDoc);
Request refreshRequest = new Request("POST", index + "/_refresh");
client().performRequest(refreshRequest);
setupDataAccessRole(index);
setupUser("ml_admin", Collections.singletonList("machine_learning_admin"));
setupUser("ml_admin_plus_data", Arrays.asList("machine_learning_admin", "test_data_access"));
String mlAdmin = basicAuthHeaderValue("ml_admin", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
String mlAdminPlusData = basicAuthHeaderValue("ml_admin_plus_data", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
Request evaluateRequest = buildRegressionEval(index, mlAdmin, mlAdminPlusData);
client().performRequest(evaluateRequest);
Request failingRequest = buildRegressionEval(index, mlAdminPlusData, mlAdmin);
expectThrows(ResponseException.class, () -> client().performRequest(failingRequest));
}
private static Request buildRegressionEval(String index, String primaryHeader, String secondaryHeader) {
Request evaluateRequest = new Request("POST", "_ml/data_frame/_evaluate");
evaluateRequest.setJsonEntity(
"{\n" +
" \"index\": \"" + index + "\",\n" +
" \"evaluation\": {\n" +
" \"regression\": {\n" +
" \"actual_field\": \"is_outlier\",\n" +
" \"predicted_field\": \"ml.outlier_score\"\n" +
" }\n" +
" }\n" +
"}\n"
);
RequestOptions.Builder options = evaluateRequest.getOptions().toBuilder();
options.addHeader("Authorization", primaryHeader);
options.addHeader("es-secondary-authorization", secondaryHeader);
evaluateRequest.setOptions(options);
return evaluateRequest;
}
}

View File

@ -13,19 +13,23 @@ import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
EvaluateDataFrameAction.Response> {
@ -33,9 +37,11 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
private final ThreadPool threadPool;
private final Client client;
private final AtomicReference<Integer> maxBuckets = new AtomicReference<>();
private final SecurityContext securityContext;
@Inject
public TransportEvaluateDataFrameAction(TransportService transportService,
Settings settings,
ActionFilters actionFilters,
ThreadPool threadPool,
Client client,
@ -43,6 +49,8 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
this.threadPool = threadPool;
this.client = client;
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
new SecurityContext(settings, threadPool.getThreadContext()) : null;
this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets);
}
@ -66,7 +74,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
// Create an immutable collection of parameters to be used by evaluation metrics.
EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get());
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request);
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request, securityContext);
evaluationExecutor.execute(resultsListener);
}
@ -89,13 +97,19 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
private final EvaluationParameters parameters;
private final EvaluateDataFrameAction.Request request;
private final Evaluation evaluation;
private final SecurityContext securityContext;
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluationParameters parameters, EvaluateDataFrameAction.Request request) {
EvaluationExecutor(ThreadPool threadPool,
Client client,
EvaluationParameters parameters,
EvaluateDataFrameAction.Request request,
SecurityContext securityContext) {
super(threadPool.generic(), unused -> true, unused -> true);
this.client = client;
this.parameters = parameters;
this.request = request;
this.evaluation = request.getEvaluation();
this.securityContext = securityContext;
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
add(nextTask());
}
@ -104,18 +118,19 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
return listener -> {
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery());
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
client.execute(
SearchAction.INSTANCE,
searchRequest,
ActionListener.wrap(
searchResponse -> {
evaluation.process(searchResponse);
if (evaluation.hasAllResults() == false) {
add(nextTask());
}
listener.onResponse(null);
},
listener::onFailure));
useSecondaryAuthIfAvailable(securityContext,
() -> client.execute(
SearchAction.INSTANCE,
searchRequest,
ActionListener.wrap(
searchResponse -> {
evaluation.process(searchResponse);
if (evaluation.hasAllResults() == false) {
add(nextTask());
}
listener.onResponse(null);
},
listener::onFailure)));
};
}
}