From c5cab37db634ae167eb3d4b8a8362d1657e708dd Mon Sep 17 00:00:00 2001 From: Jay Modi Date: Wed, 11 Jan 2017 13:41:14 -0500 Subject: [PATCH] security: always restore the ThreadContext after invoking an action This change ensure that the ThreadContext is always restored after an action has been invoked when going through the SecurityActionFilter and authentication and authorization is enabled. Original commit: elastic/x-pack-elasticsearch@5da70bd6fa2958c967f5a4a60994c2d078ec1e32 --- .../action/filter/SecurityActionFilter.java | 11 +- .../accesscontrol/IndicesAccessControl.java | 3 - .../DocumentLevelSecurityTests.java | 82 ++++++++++ .../integration/FieldLevelSecurityTests.java | 151 ++++++++++++++++++ .../filter/SecurityActionFilterTests.java | 84 +++++++++- 5 files changed, 319 insertions(+), 12 deletions(-) diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java index 10933d127cd..8a3f742ba89 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java @@ -100,13 +100,8 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil } if (licenseState.isAuthAllowed()) { - // only restore the context if it is not empty. This is needed because sometimes a response is sent to the user - // and then a cleanup action is executed (like for search without a scroll) - final boolean restoreOriginalContext = securityContext.getAuthentication() != null; final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action); - // we should always restore the original here because we forcefully changed to the system user - final ThreadContext.StoredContext toRestore = - restoreOriginalContext || useSystemUser ? threadContext.newStoredContext() : () -> {}; + final ThreadContext.StoredContext toRestore = threadContext.newStoredContext(); final ActionListener signingListener = new ContextPreservingActionListener<>(threadContext, toRestore, ActionListener.wrap(r -> { try { @@ -127,7 +122,9 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil } }); } else { - applyInternal(action, request, authenticatedListener); + try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) { + applyInternal(action, request, authenticatedListener); + } } } catch (Exception e) { listener.onFailure(e); diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/authz/accesscontrol/IndicesAccessControl.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/authz/accesscontrol/IndicesAccessControl.java index 2688823acff..16b83dfa646 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/authz/accesscontrol/IndicesAccessControl.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/authz/accesscontrol/IndicesAccessControl.java @@ -11,12 +11,9 @@ import org.elasticsearch.xpack.security.authz.IndicesAndAliasesResolver; import org.elasticsearch.xpack.security.authz.permission.FieldPermissions; import java.util.Collections; -import java.util.HashSet; import java.util.Map; import java.util.Set; -import static java.util.Collections.unmodifiableSet; - /** * Encapsulates the field and document permissions per concrete index based on the current request. */ diff --git a/elasticsearch/src/test/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java b/elasticsearch/src/test/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java index 2ed5755c3be..d884e947ebe 100644 --- a/elasticsearch/src/test/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java +++ b/elasticsearch/src/test/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.MultiGetResponse; +import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.termvectors.MultiTermVectorsResponse; import org.elasticsearch.action.termvectors.TermVectorsRequest; @@ -26,6 +27,8 @@ import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.bucket.children.Children; import org.elasticsearch.search.aggregations.bucket.global.Global; import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.sort.SortBuilders; +import org.elasticsearch.search.sort.SortMode; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.SecurityIntegTestCase; import org.elasticsearch.xpack.XPackSettings; @@ -283,6 +286,85 @@ public class DocumentLevelSecurityTests extends SecurityIntegTestCase { assertThat(response.getResponses()[0].getResponse().isExists(), is(false)); } + public void testMSearch() throws Exception { + assertAcked(client().admin().indices().prepareCreate("test1") + .addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text", "id", "type=integer") + ); + assertAcked(client().admin().indices().prepareCreate("test2") + .addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text", "id", "type=integer") + ); + + client().prepareIndex("test1", "type1", "1").setSource("field1", "value1", "id", 1).get(); + client().prepareIndex("test1", "type1", "2").setSource("field2", "value2", "id", 2).get(); + client().prepareIndex("test1", "type1", "3").setSource("field3", "value3", "id", 3).get(); + client().prepareIndex("test2", "type1", "1").setSource("field1", "value1", "id", 1).get(); + client().prepareIndex("test2", "type1", "2").setSource("field2", "value2", "id", 2).get(); + client().prepareIndex("test2", "type1", "3").setSource("field3", "value3", "id", 3).get(); + client().admin().indices().prepareRefresh("test1", "test2").get(); + + MultiSearchResponse response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("id"), is(1)); + + assertFalse(response.getResponses()[1].isFailure()); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("id"), is(1)); + + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user2", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("id"), is(2)); + + assertFalse(response.getResponses()[1].isFailure()); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("id"), is(2)); + + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user3", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").addSort(SortBuilders.fieldSort("id").sortMode(SortMode.MIN)) + .setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").addSort(SortBuilders.fieldSort("id").sortMode(SortMode.MIN)) + .setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(2L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("id"), is(1)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(1).getSource().size(), is(2)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(1).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(1).getSource().get("id"), is(2)); + + assertFalse(response.getResponses()[1].isFailure()); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(2L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("id"), is(1)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(1).getSource().size(), is(2)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(1).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(1).getSource().get("id"), is(2)); + } + public void testTVApi() throws Exception { assertAcked(client().admin().indices().prepareCreate("test") .addMapping("type1", "field1", "type=text,term_vector=with_positions_offsets_payloads", diff --git a/elasticsearch/src/test/java/org/elasticsearch/integration/FieldLevelSecurityTests.java b/elasticsearch/src/test/java/org/elasticsearch/integration/FieldLevelSecurityTests.java index 3a743325167..883e49fb09c 100644 --- a/elasticsearch/src/test/java/org/elasticsearch/integration/FieldLevelSecurityTests.java +++ b/elasticsearch/src/test/java/org/elasticsearch/integration/FieldLevelSecurityTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.fieldstats.FieldStatsResponse; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.MultiGetResponse; +import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.termvectors.MultiTermVectorsResponse; import org.elasticsearch.action.termvectors.TermVectorsRequest; @@ -20,6 +21,7 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.indices.IndicesRequestCache; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.aggregations.AggregationBuilders; @@ -488,6 +490,155 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase { assertThat(response.getResponses()[0].getResponse().getSource().get("field2").toString(), equalTo("value2")); } + public void testMSearchApi() throws Exception { + assertAcked(client().admin().indices().prepareCreate("test1") + .addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text") + ); + assertAcked(client().admin().indices().prepareCreate("test2") + .addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text") + ); + + client().prepareIndex("test1", "type1", "1") + .setSource("field1", "value1", "field2", "value2", "field3", "value3").get(); + client().prepareIndex("test2", "type1", "1") + .setSource("field1", "value1", "field2", "value2", "field3", "value3").get(); + client().admin().indices().prepareRefresh("test1", "test2").get(); + + // user1 is granted access to field1 only + MultiSearchResponse response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(1)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(1)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + + // user2 is granted access to field2 only + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user2", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(1)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(1)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + + // user3 is granted access to field1 and field2 + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user3", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + + // user4 is granted access to no fields, so the search response does say the doc exist, but no fields are returned + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user4", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(0)); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(0)); + + // user5 has no field level security configured, so all fields are returned + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user5", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(3)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3")); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(3)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3")); + + // user6 has access to field* + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user6", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(3)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3")); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(3)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3")); + + // user7 has roles with field level security and without field level security + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user7", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(3)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3")); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(3)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field3"), is("value3")); + + // user8 has roles with field level security with access to field1 and field2 + response = client() + .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user8", USERS_PASSWD))) + .prepareMultiSearch() + .add(client().prepareSearch("test1").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .add(client().prepareSearch("test2").setTypes("type1").setQuery(QueryBuilders.matchAllQuery())) + .get(); + assertFalse(response.getResponses()[0].isFailure()); + assertThat(response.getResponses()[0].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[0].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + assertThat(response.getResponses()[1].getResponse().getHits().getTotalHits(), is(1L)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().size(), is(2)); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field1"), is("value1")); + assertThat(response.getResponses()[1].getResponse().getHits().getAt(0).getSource().get("field2"), is("value2")); + } + public void testFieldStatsApi() throws Exception { assertAcked(client().admin().indices().prepareCreate("test") .addMapping("type1", "field1", "type=text", "field2", "type=text", "field3", "type=text") diff --git a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java index 8b1c095a431..a0c67677381 100644 --- a/elasticsearch/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java +++ b/elasticsearch/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.action.filter; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -60,6 +61,7 @@ public class SecurityActionFilterTests extends ESTestCase { private AuditTrailService auditTrail; private XPackLicenseState licenseState; private SecurityActionFilter filter; + private ThreadContext threadContext; private boolean failDestructiveOperations; @Before @@ -72,14 +74,17 @@ public class SecurityActionFilterTests extends ESTestCase { when(licenseState.isAuthAllowed()).thenReturn(true); when(licenseState.isStatsAndHealthAllowed()).thenReturn(true); ThreadPool threadPool = mock(ThreadPool.class); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + threadContext = new ThreadContext(Settings.EMPTY); + when(threadPool.getThreadContext()).thenReturn(threadContext); failDestructiveOperations = randomBoolean(); Settings settings = Settings.builder() .put(DestructiveOperations.REQUIRES_NAME_SETTING.getKey(), failDestructiveOperations).build(); DestructiveOperations destructiveOperations = new DestructiveOperations(settings, new ClusterSettings(settings, Collections.singleton(DestructiveOperations.REQUIRES_NAME_SETTING))); + + SecurityContext securityContext = new SecurityContext(settings, threadContext, cryptoService); filter = new SecurityActionFilter(Settings.EMPTY, authcService, authzService, cryptoService, auditTrail, - licenseState, new HashSet<>(), threadPool, mock(SecurityContext.class), destructiveOperations); + licenseState, new HashSet<>(), threadPool, securityContext, destructiveOperations); } public void testApply() throws Exception { @@ -108,6 +113,81 @@ public class SecurityActionFilterTests extends ESTestCase { verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class)); } + public void testApplyRestoresThreadContext() throws Exception { + ActionRequest request = mock(ActionRequest.class); + ActionListener listener = mock(ActionListener.class); + ActionFilterChain chain = mock(ActionFilterChain.class); + Task task = mock(Task.class); + User user = new User("username", "r1", "r2"); + Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null); + doAnswer((i) -> { + ActionListener callback = + (ActionListener) i.getArguments()[3]; + assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + threadContext.putTransient(Authentication.AUTHENTICATION_KEY, authentication); + callback.onResponse(authentication); + return Void.TYPE; + }).when(authcService).authenticate(eq("_action"), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class)); + final Role empty = Role.EMPTY; + doAnswer((i) -> { + ActionListener callback = + (ActionListener) i.getArguments()[1]; + assertEquals(authentication, threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + callback.onResponse(empty); + return Void.TYPE; + }).when(authzService).roles(any(User.class), any(ActionListener.class)); + doReturn(request).when(spy(filter)).unsign(user, "_action", request); + assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + + filter.apply(task, "_action", request, listener, chain); + + assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + verify(authzService).authorize(authentication, "_action", request, empty, null); + verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class)); + } + + public void testApplyAsSystemUser() throws Exception { + ActionRequest request = mock(ActionRequest.class); + ActionListener listener = mock(ActionListener.class); + User user = new User("username", "r1", "r2"); + Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null); + SetOnce authenticationSetOnce = new SetOnce<>(); + ActionFilterChain chain = (task, action, request1, listener1) -> { + authenticationSetOnce.set(threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + }; + Task task = mock(Task.class); + final boolean hasExistingAuthentication = randomBoolean(); + final String action = "internal:foo"; + if (hasExistingAuthentication) { + threadContext.putTransient(Authentication.AUTHENTICATION_KEY, authentication); + threadContext.putTransient(AuthorizationService.ORIGINATING_ACTION_KEY, "indices:foo"); + } else { + assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + } + doAnswer((i) -> { + ActionListener callback = + (ActionListener) i.getArguments()[3]; + callback.onResponse(threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + return Void.TYPE; + }).when(authcService).authenticate(eq(action), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class)); + doReturn(request).when(spy(filter)).unsign(user, action, request); + doAnswer((i) -> { + String text = (String) i.getArguments()[0]; + return text; + }).when(cryptoService).sign(any(String.class)); + + filter.apply(task, action, request, listener, chain); + + if (hasExistingAuthentication) { + assertEquals(authentication, threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + } else { + assertNull(threadContext.getTransient(Authentication.AUTHENTICATION_KEY)); + } + assertNotNull(authenticationSetOnce.get()); + assertNotEquals(authentication, authenticationSetOnce.get()); + assertEquals(SystemUser.INSTANCE, authenticationSetOnce.get().getUser()); + } + public void testApplyDestructiveOperations() throws Exception { ActionRequest request = new MockIndicesRequest( IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()),