Extract ClientHelper.filterSecurityHeaders method and use it in ML code (#58447) (#58459)

This commit is contained in:
Przemysław Witek 2020-06-23 22:18:39 +02:00 committed by GitHub
parent b40c27698f
commit 4e4ca6ac25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 97 additions and 57 deletions

View File

@ -5,11 +5,11 @@
*/
package org.elasticsearch.xpack.core;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
@ -33,9 +34,23 @@ public final class ClientHelper {
/**
* List of headers that are related to security
*/
public static final Set<String> SECURITY_HEADER_FILTERS = Sets.newHashSet(AuthenticationServiceField.RUN_AS_USER_HEADER,
public static final Set<String> SECURITY_HEADER_FILTERS =
Sets.newHashSet(
AuthenticationServiceField.RUN_AS_USER_HEADER,
AuthenticationField.AUTHENTICATION_KEY);
/**
* Leaves only headers that are related to security and filters out the rest.
*
* @param headers Headers to be filtered
* @return A portion of entries that are related to security
*/
public static Map<String, String> filterSecurityHeaders(Map<String, String> headers) {
return Objects.requireNonNull(headers).entrySet().stream()
.filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
/**
* .
* @deprecated use ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME
@ -109,8 +124,7 @@ public final class ClientHelper {
*/
public static <T extends ActionResponse> T executeWithHeaders(Map<String, String> headers, String origin, Client client,
Supplier<T> supplier) {
Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
// no security headers, we will have to use the xpack internal user for
// our execution by specifying the origin
@ -145,8 +159,7 @@ public final class ClientHelper {
void executeWithHeadersAsync(Map<String, String> headers, String origin, Client client, ActionType<Response> action, Request request,
ActionListener<Response> listener) {
Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
final ThreadContext threadContext = client.threadPool().getThreadContext();

View File

@ -22,7 +22,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedJobValidator;
@ -42,7 +41,8 @@ import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
public class MlMetadata implements XPackPlugin.XPackMetadataCustom {
@ -315,12 +315,9 @@ public class MlMetadata implements XPackPlugin.XPackMetadataCustom {
if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig);
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
datafeedConfig = builder.build();
datafeedConfig = new DatafeedConfig.Builder(datafeedConfig)
.setHeaders(filterSecurityHeaders(headers))
.build();
}
datafeeds.put(datafeedConfig.getId(), datafeedConfig);

View File

@ -659,98 +659,113 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
this.indicesOptions = config.indicesOptions;
}
public void setId(String datafeedId) {
public Builder setId(String datafeedId) {
id = ExceptionsHelper.requireNonNull(datafeedId, ID.getPreferredName());
return this;
}
public String getId() {
return id;
}
public void setJobId(String jobId) {
public Builder setJobId(String jobId) {
this.jobId = ExceptionsHelper.requireNonNull(jobId, Job.ID.getPreferredName());
return this;
}
public void setHeaders(Map<String, String> headers) {
public Builder setHeaders(Map<String, String> headers) {
this.headers = ExceptionsHelper.requireNonNull(headers, HEADERS.getPreferredName());
return this;
}
public void setIndices(List<String> indices) {
public Builder setIndices(List<String> indices) {
this.indices = ExceptionsHelper.requireNonNull(indices, INDICES.getPreferredName());
return this;
}
public void setQueryDelay(TimeValue queryDelay) {
public Builder setQueryDelay(TimeValue queryDelay) {
TimeUtils.checkNonNegativeMultiple(queryDelay, TimeUnit.MILLISECONDS, QUERY_DELAY);
this.queryDelay = queryDelay;
return this;
}
public void setFrequency(TimeValue frequency) {
public Builder setFrequency(TimeValue frequency) {
TimeUtils.checkPositiveMultiple(frequency, TimeUnit.SECONDS, FREQUENCY);
this.frequency = frequency;
return this;
}
public void setQueryProvider(QueryProvider queryProvider) {
public Builder setQueryProvider(QueryProvider queryProvider) {
this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY.getPreferredName());
return this;
}
// For testing only
public void setParsedQuery(QueryBuilder queryBuilder) {
public Builder setParsedQuery(QueryBuilder queryBuilder) {
try {
this.queryProvider = ExceptionsHelper.requireNonNull(QueryProvider.fromParsedQuery(queryBuilder), QUERY.getPreferredName());
} catch (IOException exception) {
// eat exception as it should never happen
logger.error("Exception trying to setParsedQuery", exception);
}
return this;
}
// For testing only
public void setParsedAggregations(AggregatorFactories.Builder aggregations) {
public Builder setParsedAggregations(AggregatorFactories.Builder aggregations) {
try {
this.aggProvider = AggProvider.fromParsedAggs(aggregations);
} catch (IOException exception) {
// eat exception as it should never happen
logger.error("Exception trying to setParsedAggregations", exception);
}
return this;
}
private void setAggregationsSafe(AggProvider aggProvider) {
private Builder setAggregationsSafe(AggProvider aggProvider) {
if (this.aggProvider != null) {
throw ExceptionsHelper.badRequestException("Found two aggregation definitions: [aggs] and [aggregations]");
}
this.aggProvider = aggProvider;
return this;
}
public void setAggProvider(AggProvider aggProvider) {
public Builder setAggProvider(AggProvider aggProvider) {
this.aggProvider = aggProvider;
return this;
}
public void setScriptFields(List<SearchSourceBuilder.ScriptField> scriptFields) {
public Builder setScriptFields(List<SearchSourceBuilder.ScriptField> scriptFields) {
List<SearchSourceBuilder.ScriptField> sorted = new ArrayList<>();
for (SearchSourceBuilder.ScriptField scriptField : scriptFields) {
sorted.add(scriptField);
}
sorted.sort(Comparator.comparing(SearchSourceBuilder.ScriptField::fieldName));
this.scriptFields = sorted;
return this;
}
public void setScrollSize(int scrollSize) {
public Builder setScrollSize(int scrollSize) {
if (scrollSize < 0) {
String msg = Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE,
DatafeedConfig.SCROLL_SIZE.getPreferredName(), scrollSize);
throw ExceptionsHelper.badRequestException(msg);
}
this.scrollSize = scrollSize;
return this;
}
public void setChunkingConfig(ChunkingConfig chunkingConfig) {
public Builder setChunkingConfig(ChunkingConfig chunkingConfig) {
this.chunkingConfig = chunkingConfig;
return this;
}
public void setDelayedDataCheckConfig(DelayedDataCheckConfig delayedDataCheckConfig) {
public Builder setDelayedDataCheckConfig(DelayedDataCheckConfig delayedDataCheckConfig) {
this.delayedDataCheckConfig = delayedDataCheckConfig;
return this;
}
public void setMaxEmptySearches(int maxEmptySearches) {
public Builder setMaxEmptySearches(int maxEmptySearches) {
if (maxEmptySearches == -1) {
this.maxEmptySearches = null;
} else if (maxEmptySearches <= 0) {
@ -760,6 +775,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
} else {
this.maxEmptySearches = maxEmptySearches;
}
return this;
}
public Builder setIndicesOptions(IndicesOptions indicesOptions) {

View File

@ -24,7 +24,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -38,7 +37,8 @@ import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
/**
@ -389,10 +389,7 @@ public class DatafeedUpdate implements Writeable, ToXContentObject {
if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
builder.setHeaders(filterSecurityHeaders(headers));
}
return builder.build();

View File

@ -31,6 +31,8 @@ import java.util.concurrent.CountDownLatch;
import java.util.function.Consumer;
import static org.elasticsearch.xpack.core.ClientHelper.ACTION_ORIGIN_TRANSIENT_NAME;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
@ -311,4 +313,30 @@ public class ClientHelperTests extends ESTestCase {
return client.search(new SearchRequest()).actionGet();
});
}
public void testFilterSecurityHeaders() {
{ // Empty map
assertThat(ClientHelper.filterSecurityHeaders(Collections.emptyMap()), is(anEmptyMap()));
}
{ // Singleton map with no security-related headers
assertThat(ClientHelper.filterSecurityHeaders(Collections.singletonMap("non-security-header", "value")), is(anEmptyMap()));
}
{ // Singleton map with a security-related header
assertThat(
ClientHelper.filterSecurityHeaders(Collections.singletonMap(AuthenticationServiceField.RUN_AS_USER_HEADER, "value")),
hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "value"));
}
{ // Map with 3 headers out of which only 1 is security-related
Map<String, String> headers = new HashMap<>();
headers.put("non-security-header-1", "value-1");
headers.put(AuthenticationServiceField.RUN_AS_USER_HEADER, "value-2");
headers.put("other-non-security-header", "value-3");
Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(headers);
assertThat(filteredHeaders, is(aMapWithSize(1)));
assertThat(filteredHeaders, hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "value-2"));
}
{ // null
expectThrows(NullPointerException.class, () -> ClientHelper.filterSecurityHeaders(null));
}
}
}

View File

@ -16,7 +16,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction;
import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig;
@ -33,10 +32,10 @@ import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
public class TransportPreviewDatafeedAction extends HandledTransportAction<PreviewDatafeedAction.Request, PreviewDatafeedAction.Response> {
@ -74,10 +73,7 @@ public class TransportPreviewDatafeedAction extends HandledTransportAction<Previ
jobBuilder -> {
DatafeedConfig.Builder previewDatafeed = buildPreviewDatafeed(datafeedConfig);
useSecondaryAuthIfAvailable(securityContext, () -> {
Map<String, String> headers = threadPool.getThreadContext().getHeaders().entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
previewDatafeed.setHeaders(headers);
previewDatafeed.setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders()));
jobResultsProvider.datafeedTimingStats(
jobBuilder.getId(),
timingStats -> {

View File

@ -45,7 +45,6 @@ import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
@ -69,10 +68,10 @@ import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
/**
* This class implements CRUD operation for the
@ -108,12 +107,9 @@ public class DatafeedConfigProvider {
if (headers.isEmpty() == false) {
// Filter any values in headers that aren't security fields
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(config);
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
config = builder.build();
config = new DatafeedConfig.Builder(config)
.setHeaders(filterSecurityHeaders(headers))
.build();
}
final String datafeedId = config.getId();

View File

@ -30,7 +30,6 @@ import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@ -51,6 +50,7 @@ import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
public class DataFrameAnalyticsConfigProvider {
@ -73,12 +73,9 @@ public class DataFrameAnalyticsConfigProvider {
if (headers.isEmpty() == false) {
// Filter any values in headers that aren't security fields
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(config);
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
config = builder.build();
config = new DataFrameAnalyticsConfig.Builder(config)
.setHeaders(filterSecurityHeaders(headers))
.build();
}
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
config.toXContent(builder, new ToXContent.MapParams(TO_XCONTENT_PARAMS));