mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-08 22:14:59 +00:00
* [ML] prefer secondary auth headers on evaluate (#59167) We should prefer the secondary auth headers when evaluating a data frame
This commit is contained in:
parent
24c6a30e2b
commit
e343e066fc
@ -0,0 +1,98 @@
|
|||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.Request;
|
||||||
|
import org.elasticsearch.client.RequestOptions;
|
||||||
|
import org.elasticsearch.client.ResponseException;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||||
|
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||||
|
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
||||||
|
|
||||||
|
public class ClassificationEvaluationWithSecurityIT extends ESRestTestCase {
|
||||||
|
private static final String BASIC_AUTH_VALUE_SUPER_USER =
|
||||||
|
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Settings restClientSettings() {
|
||||||
|
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void setupDataAccessRole(String index) throws IOException {
|
||||||
|
Request request = new Request("PUT", "/_security/role/test_data_access");
|
||||||
|
request.setJsonEntity("{"
|
||||||
|
+ " \"indices\" : ["
|
||||||
|
+ " { \"names\": [\"" + index + "\"], \"privileges\": [\"read\"] }"
|
||||||
|
+ " ]"
|
||||||
|
+ "}");
|
||||||
|
client().performRequest(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setupUser(String user, List<String> roles) throws IOException {
|
||||||
|
String password = new String(SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING.getChars());
|
||||||
|
|
||||||
|
Request request = new Request("PUT", "/_security/user/" + user);
|
||||||
|
request.setJsonEntity("{"
|
||||||
|
+ " \"password\" : \"" + password + "\","
|
||||||
|
+ " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]"
|
||||||
|
+ "}");
|
||||||
|
client().performRequest(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testEvaluate_withSecurity() throws Exception {
|
||||||
|
String index = "test_data";
|
||||||
|
Request createDoc = new Request("POST", index + "/_doc");
|
||||||
|
createDoc.setJsonEntity(
|
||||||
|
"{\n" +
|
||||||
|
" \"is_outlier\": 0.0,\n" +
|
||||||
|
" \"ml.outlier_score\": 1.0\n" +
|
||||||
|
"}"
|
||||||
|
);
|
||||||
|
client().performRequest(createDoc);
|
||||||
|
Request refreshRequest = new Request("POST", index + "/_refresh");
|
||||||
|
client().performRequest(refreshRequest);
|
||||||
|
setupDataAccessRole(index);
|
||||||
|
setupUser("ml_admin", Collections.singletonList("machine_learning_admin"));
|
||||||
|
setupUser("ml_admin_plus_data", Arrays.asList("machine_learning_admin", "test_data_access"));
|
||||||
|
String mlAdmin = basicAuthHeaderValue("ml_admin", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||||
|
String mlAdminPlusData = basicAuthHeaderValue("ml_admin_plus_data", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||||
|
Request evaluateRequest = buildRegressionEval(index, mlAdmin, mlAdminPlusData);
|
||||||
|
client().performRequest(evaluateRequest);
|
||||||
|
|
||||||
|
Request failingRequest = buildRegressionEval(index, mlAdminPlusData, mlAdmin);
|
||||||
|
expectThrows(ResponseException.class, () -> client().performRequest(failingRequest));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Request buildRegressionEval(String index, String primaryHeader, String secondaryHeader) {
|
||||||
|
Request evaluateRequest = new Request("POST", "_ml/data_frame/_evaluate");
|
||||||
|
evaluateRequest.setJsonEntity(
|
||||||
|
"{\n" +
|
||||||
|
" \"index\": \"" + index + "\",\n" +
|
||||||
|
" \"evaluation\": {\n" +
|
||||||
|
" \"regression\": {\n" +
|
||||||
|
" \"actual_field\": \"is_outlier\",\n" +
|
||||||
|
" \"predicted_field\": \"ml.outlier_score\"\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
"}\n"
|
||||||
|
);
|
||||||
|
RequestOptions.Builder options = evaluateRequest.getOptions().toBuilder();
|
||||||
|
options.addHeader("Authorization", primaryHeader);
|
||||||
|
options.addHeader("es-secondary-authorization", secondaryHeader);
|
||||||
|
evaluateRequest.setOptions(options);
|
||||||
|
return evaluateRequest;
|
||||||
|
}
|
||||||
|
}
|
@ -13,19 +13,23 @@ import org.elasticsearch.action.support.HandledTransportAction;
|
|||||||
import org.elasticsearch.client.Client;
|
import org.elasticsearch.client.Client;
|
||||||
import org.elasticsearch.cluster.service.ClusterService;
|
import org.elasticsearch.cluster.service.ClusterService;
|
||||||
import org.elasticsearch.common.inject.Inject;
|
import org.elasticsearch.common.inject.Inject;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
import org.elasticsearch.tasks.Task;
|
import org.elasticsearch.tasks.Task;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.transport.TransportService;
|
import org.elasticsearch.transport.TransportService;
|
||||||
|
import org.elasticsearch.xpack.core.XPackSettings;
|
||||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
import org.elasticsearch.xpack.core.security.SecurityContext;
|
||||||
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
|
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
|
||||||
|
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
|
||||||
|
|
||||||
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
|
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
|
||||||
EvaluateDataFrameAction.Response> {
|
EvaluateDataFrameAction.Response> {
|
||||||
@ -33,9 +37,11 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||||||
private final ThreadPool threadPool;
|
private final ThreadPool threadPool;
|
||||||
private final Client client;
|
private final Client client;
|
||||||
private final AtomicReference<Integer> maxBuckets = new AtomicReference<>();
|
private final AtomicReference<Integer> maxBuckets = new AtomicReference<>();
|
||||||
|
private final SecurityContext securityContext;
|
||||||
|
|
||||||
@Inject
|
@Inject
|
||||||
public TransportEvaluateDataFrameAction(TransportService transportService,
|
public TransportEvaluateDataFrameAction(TransportService transportService,
|
||||||
|
Settings settings,
|
||||||
ActionFilters actionFilters,
|
ActionFilters actionFilters,
|
||||||
ThreadPool threadPool,
|
ThreadPool threadPool,
|
||||||
Client client,
|
Client client,
|
||||||
@ -43,6 +49,8 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||||||
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
|
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
|
||||||
this.threadPool = threadPool;
|
this.threadPool = threadPool;
|
||||||
this.client = client;
|
this.client = client;
|
||||||
|
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
|
||||||
|
new SecurityContext(settings, threadPool.getThreadContext()) : null;
|
||||||
this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings()));
|
this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings()));
|
||||||
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets);
|
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets);
|
||||||
}
|
}
|
||||||
@ -66,7 +74,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||||||
|
|
||||||
// Create an immutable collection of parameters to be used by evaluation metrics.
|
// Create an immutable collection of parameters to be used by evaluation metrics.
|
||||||
EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get());
|
EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get());
|
||||||
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request);
|
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request, securityContext);
|
||||||
evaluationExecutor.execute(resultsListener);
|
evaluationExecutor.execute(resultsListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,13 +97,19 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||||||
private final EvaluationParameters parameters;
|
private final EvaluationParameters parameters;
|
||||||
private final EvaluateDataFrameAction.Request request;
|
private final EvaluateDataFrameAction.Request request;
|
||||||
private final Evaluation evaluation;
|
private final Evaluation evaluation;
|
||||||
|
private final SecurityContext securityContext;
|
||||||
|
|
||||||
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluationParameters parameters, EvaluateDataFrameAction.Request request) {
|
EvaluationExecutor(ThreadPool threadPool,
|
||||||
|
Client client,
|
||||||
|
EvaluationParameters parameters,
|
||||||
|
EvaluateDataFrameAction.Request request,
|
||||||
|
SecurityContext securityContext) {
|
||||||
super(threadPool.generic(), unused -> true, unused -> true);
|
super(threadPool.generic(), unused -> true, unused -> true);
|
||||||
this.client = client;
|
this.client = client;
|
||||||
this.parameters = parameters;
|
this.parameters = parameters;
|
||||||
this.request = request;
|
this.request = request;
|
||||||
this.evaluation = request.getEvaluation();
|
this.evaluation = request.getEvaluation();
|
||||||
|
this.securityContext = securityContext;
|
||||||
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
|
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
|
||||||
add(nextTask());
|
add(nextTask());
|
||||||
}
|
}
|
||||||
@ -104,18 +118,19 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||||||
return listener -> {
|
return listener -> {
|
||||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery());
|
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery());
|
||||||
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
|
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
|
||||||
client.execute(
|
useSecondaryAuthIfAvailable(securityContext,
|
||||||
SearchAction.INSTANCE,
|
() -> client.execute(
|
||||||
searchRequest,
|
SearchAction.INSTANCE,
|
||||||
ActionListener.wrap(
|
searchRequest,
|
||||||
searchResponse -> {
|
ActionListener.wrap(
|
||||||
evaluation.process(searchResponse);
|
searchResponse -> {
|
||||||
if (evaluation.hasAllResults() == false) {
|
evaluation.process(searchResponse);
|
||||||
add(nextTask());
|
if (evaluation.hasAllResults() == false) {
|
||||||
}
|
add(nextTask());
|
||||||
listener.onResponse(null);
|
}
|
||||||
},
|
listener.onResponse(null);
|
||||||
listener::onFailure));
|
},
|
||||||
|
listener::onFailure)));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user