security: always restore the ThreadContext after invoking an action

This change ensure that the ThreadContext is always restored after an action has been invoked when
going through the SecurityActionFilter and authentication and authorization is enabled.

Original commit: elastic/x-pack-elasticsearch@5da70bd6fa
This commit is contained in:
Jay Modi 2017-01-11 13:41:14 -05:00 committed by GitHub
parent 33e670e3aa
commit c5cab37db6
5 changed files with 319 additions and 12 deletions

View File

@ -100,13 +100,8 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
} }
if (licenseState.isAuthAllowed()) { if (licenseState.isAuthAllowed()) {
// only restore the context if it is not empty. This is needed because sometimes a response is sent to the user
// and then a cleanup action is executed (like for search without a scroll)
final boolean restoreOriginalContext = securityContext.getAuthentication() != null;
final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action); final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action);
// we should always restore the original here because we forcefully changed to the system user final ThreadContext.StoredContext toRestore = threadContext.newStoredContext();
final ThreadContext.StoredContext toRestore =
restoreOriginalContext || useSystemUser ? threadContext.newStoredContext() : () -> {};
final ActionListener<ActionResponse> signingListener = new ContextPreservingActionListener<>(threadContext, toRestore, final ActionListener<ActionResponse> signingListener = new ContextPreservingActionListener<>(threadContext, toRestore,
ActionListener.wrap(r -> { ActionListener.wrap(r -> {
try { try {
@ -127,7 +122,9 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
} }
}); });
} else { } else {
applyInternal(action, request, authenticatedListener); try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
applyInternal(action, request, authenticatedListener);
}
} }
} catch (Exception e) { } catch (Exception e) {
listener.onFailure(e); listener.onFailure(e);

View File

@ -11,12 +11,9 @@ import org.elasticsearch.xpack.security.authz.IndicesAndAliasesResolver;
import org.elasticsearch.xpack.security.authz.permission.FieldPermissions; import org.elasticsearch.xpack.security.authz.permission.FieldPermissions;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import static java.util.Collections.unmodifiableSet;
/** /**
* Encapsulates the field and document permissions per concrete index based on the current request. * Encapsulates the field and document permissions per concrete index based on the current request.
*/ */

View File

@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.termvectors.MultiTermVectorsResponse; import org.elasticsearch.action.termvectors.MultiTermVectorsResponse;
import org.elasticsearch.action.termvectors.TermVectorsRequest; import org.elasticsearch.action.termvectors.TermVectorsRequest;
@ -26,6 +27,8 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.children.Children; import org.elasticsearch.search.aggregations.bucket.children.Children;
import org.elasticsearch.search.aggregations.bucket.global.Global; import org.elasticsearch.search.aggregations.bucket.global.Global;
import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortMode;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.SecurityIntegTestCase; import org.elasticsearch.test.SecurityIntegTestCase;
import org.elasticsearch.xpack.XPackSettings; import org.elasticsearch.xpack.XPackSettings;
@ -283,6 +286,85 @@ public class DocumentLevelSecurityTests extends SecurityIntegTestCase {
assertThat(response.getResponses()[0].getResponse().isExists(), is(false)); assertThat(response.getResponses()[0].getResponse().isExists(), is(false));
} }
public void testMSearch() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test1")
.addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text", "id", "type=integer")
);
assertAcked(client().admin().indices().prepareCreate("test2")
.addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text", "id", "type=integer")
);
client().prepareIndex("test1", "type1", "1").setSource("field1", "value1", "id", 1).get();
client().prepareIndex("test1", "type1", "2").setSource("field2", "value2", "id", 2).get();
client().prepareIndex("test1", "type1", "3").setSource("field3", "value3", "id", 3).get();
client().prepareIndex("test2", "type1", "1").setSource("field1", "value1", "id", 1).get();
client().prepareIndex("test2", "type1", "2").setSource("field2", "value2", "id", 2).get();
client().prepareIndex("test2", "type1", "3").setSource("field3", "value3", "id", 3).get();
client().admin().indices().prepareRefresh("test1", "test2").get();
MultiSearchResponse response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("id"), is(1));
assertFalse(response.getResponses()[1].isFailure());
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("id"), is(1));
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user2", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("id"), is(2));
assertFalse(response.getResponses()[1].isFailure());
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("id"), is(2));
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user3", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").addSort(SortBuilders.fieldSort("id").sortMode(SortMode.MIN))
.setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").addSort(SortBuilders.fieldSort("id").sortMode(SortMode.MIN))
.setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(2L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("id"), is(1));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(1).getSource().size(), is(2));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(1).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(1).getSource().get("id"), is(2));
assertFalse(response.getResponses()[1].isFailure());
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(2L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("id"), is(1));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(1).getSource().size(), is(2));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(1).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(1).getSource().get("id"), is(2));
}
public void testTVApi() throws Exception { public void testTVApi() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test") assertAcked(client().admin().indices().prepareCreate("test")
.addMapping("type1", "field1", "type=text,term_vector=with_positions_offsets_payloads", .addMapping("type1", "field1", "type=text,term_vector=with_positions_offsets_payloads",

View File

@ -12,6 +12,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.fieldstats.FieldStatsResponse; import org.elasticsearch.action.fieldstats.FieldStatsResponse;
import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.termvectors.MultiTermVectorsResponse; import org.elasticsearch.action.termvectors.MultiTermVectorsResponse;
import org.elasticsearch.action.termvectors.TermVectorsRequest; import org.elasticsearch.action.termvectors.TermVectorsRequest;
@ -20,6 +21,7 @@ import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexModule;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.indices.IndicesRequestCache; import org.elasticsearch.indices.IndicesRequestCache;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.AggregationBuilders;
@ -488,6 +490,155 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
assertThat(response.getResponses()[0].getResponse().getSource().get("field2").toString(), equalTo("value2")); assertThat(response.getResponses()[0].getResponse().getSource().get("field2").toString(), equalTo("value2"));
} }
public void testMSearchApi() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test1")
.addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text")
);
assertAcked(client().admin().indices().prepareCreate("test2")
.addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text")
);
client().prepareIndex("test1", "type1", "1")
.setSource("field1", "value1", "field2", "value2", "field3", "value3").get();
client().prepareIndex("test2", "type1", "1")
.setSource("field1", "value1", "field2", "value2", "field3", "value3").get();
client().admin().indices().prepareRefresh("test1", "test2").get();
// user1 is granted access to field1 only
MultiSearchResponse response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(1));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(1));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
// user2 is granted access to field2 only
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user2", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(1));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(1));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
// user3 is granted access to field1 and field2
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user3", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
// user4 is granted access to no fields, so the search response does say the doc exist, but no fields are returned
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user4", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(0));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(0));
// user5 has no field level security configured, so all fields are returned
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user5", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(3));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3"));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(3));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3"));
// user6 has access to field*
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user6", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(3));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3"));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(3));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3"));
// user7 has roles with field level security and without field level security
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user7", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(3));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3"));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(3));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3"));
// user8 has roles with field level security with access to field1 and field2
response = client()
.filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user8", USERS_PASSWD)))
.prepareMultiSearch()
.add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery()))
.get();
assertFalse(response.getResponses()[0].isFailure());
assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1"));
assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2"));
}
public void testFieldStatsApi() throws Exception { public void testFieldStatsApi() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test") assertAcked(client().admin().indices().prepareCreate("test")
.addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text") .addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text")

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.security.action.filter; package org.elasticsearch.xpack.security.action.filter;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
@ -60,6 +61,7 @@ public class SecurityActionFilterTests extends ESTestCase {
private AuditTrailService auditTrail; private AuditTrailService auditTrail;
private XPackLicenseState licenseState; private XPackLicenseState licenseState;
private SecurityActionFilter filter; private SecurityActionFilter filter;
private ThreadContext threadContext;
private boolean failDestructiveOperations; private boolean failDestructiveOperations;
@Before @Before
@ -72,14 +74,17 @@ public class SecurityActionFilterTests extends ESTestCase {
when(licenseState.isAuthAllowed()).thenReturn(true); when(licenseState.isAuthAllowed()).thenReturn(true);
when(licenseState.isStatsAndHealthAllowed()).thenReturn(true); when(licenseState.isStatsAndHealthAllowed()).thenReturn(true);
ThreadPool threadPool = mock(ThreadPool.class); ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); threadContext = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext);
failDestructiveOperations = randomBoolean(); failDestructiveOperations = randomBoolean();
Settings settings = Settings.builder() Settings settings = Settings.builder()
.put(DestructiveOperations.REQUIRES_NAME_SETTING.getKey(), failDestructiveOperations).build(); .put(DestructiveOperations.REQUIRES_NAME_SETTING.getKey(), failDestructiveOperations).build();
DestructiveOperations destructiveOperations = new DestructiveOperations(settings, DestructiveOperations destructiveOperations = new DestructiveOperations(settings,
new ClusterSettings(settings, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING))); new ClusterSettings(settings, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING)));
SecurityContext securityContext = new SecurityContext(settings, threadContext, cryptoService);
filter = new SecurityActionFilter(Settings.EMPTY, authcService, authzService, cryptoService, auditTrail, filter = new SecurityActionFilter(Settings.EMPTY, authcService, authzService, cryptoService, auditTrail,
licenseState, new HashSet<>(), threadPool, mock(SecurityContext.class), destructiveOperations); licenseState, new HashSet<>(), threadPool, securityContext, destructiveOperations);
} }
public void testApply() throws Exception { public void testApply() throws Exception {
@ -108,6 +113,81 @@ public class SecurityActionFilterTests extends ESTestCase {
verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class)); verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class));
} }
public void testApplyRestoresThreadContext() throws Exception {
ActionRequest request = mock(ActionRequest.class);
ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class);
Task task = mock(Task.class);
User user = new User("username", "r1", "r2");
Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null);
doAnswer((i) -> {
ActionListener callback =
(ActionListener) i.getArguments()[3];
assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
threadContext.putTransient(Authentication.AUTHENTICATION_KEY, authentication);
callback.onResponse(authentication);
return Void.TYPE;
}).when(authcService).authenticate(eq("_action"), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class));
final Role empty = Role.EMPTY;
doAnswer((i) -> {
ActionListener callback =
(ActionListener) i.getArguments()[1];
assertEquals(authentication, threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
callback.onResponse(empty);
return Void.TYPE;
}).when(authzService).roles(any(User.class), any(ActionListener.class));
doReturn(request).when(spy(filter)).unsign(user, "_action", request);
assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
filter.apply(task, "_action", request, listener, chain);
assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
verify(authzService).authorize(authentication, "_action", request, empty, null);
verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class));
}
public void testApplyAsSystemUser() throws Exception {
ActionRequest request = mock(ActionRequest.class);
ActionListener listener = mock(ActionListener.class);
User user = new User("username", "r1", "r2");
Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null);
SetOnce<Authentication> authenticationSetOnce = new SetOnce<>();
ActionFilterChain chain = (task, action, request1, listener1) -> {
authenticationSetOnce.set(threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
};
Task task = mock(Task.class);
final boolean hasExistingAuthentication = randomBoolean();
final String action = "internal:foo";
if (hasExistingAuthentication) {
threadContext.putTransient(Authentication.AUTHENTICATION_KEY, authentication);
threadContext.putTransient(AuthorizationService.ORIGINATING_ACTION_KEY, "indices:foo");
} else {
assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
}
doAnswer((i) -> {
ActionListener callback =
(ActionListener) i.getArguments()[3];
callback.onResponse(threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
return Void.TYPE;
}).when(authcService).authenticate(eq(action), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class));
doReturn(request).when(spy(filter)).unsign(user, action, request);
doAnswer((i) -> {
String text = (String) i.getArguments()[0];
return text;
}).when(cryptoService).sign(any(String.class));
filter.apply(task, action, request, listener, chain);
if (hasExistingAuthentication) {
assertEquals(authentication, threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
} else {
assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY));
}
assertNotNull(authenticationSetOnce.get());
assertNotEquals(authentication, authenticationSetOnce.get());
assertEquals(SystemUser.INSTANCE, authenticationSetOnce.get().getUser());
}
public void testApplyDestructiveOperations() throws Exception { public void testApplyDestructiveOperations() throws Exception {
ActionRequest request = new MockIndicesRequest( ActionRequest request = new MockIndicesRequest(
IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()), IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()),