diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/ScrollHelper.java b/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/ScrollHelper.java index a70c6c2f9ff..a481f880311 100644 --- a/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/ScrollHelper.java +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/ScrollHelper.java @@ -67,6 +67,10 @@ public final class ScrollHelper { listener.onFailure(new IllegalStateException("scrolling returned more hits [" + results.size() + "] than expected [" + resp.getHits().getTotalHits() + "] so bailing out to prevent unbounded " + "memory consumption.")); + } else if (results.size() == resp.getHits().getTotalHits()) { + clearScroll.accept(resp); + // Finally, return the list of the entity + listener.onResponse(Collections.unmodifiableList(results)); } else { SearchScrollRequest scrollRequest = new SearchScrollRequest(resp.getScrollId()); scrollRequest.scroll(request.scroll().keepAlive()); @@ -74,7 +78,7 @@ public final class ScrollHelper { } } else { clearScroll.accept(resp); - // Finally, return the list of users + // Finally, return the list of the entity listener.onResponse(Collections.unmodifiableList(results)); } } catch (Exception e){ diff --git a/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java b/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java index 517b13933cf..143b3ffd64b 100644 --- a/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java +++ b/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.security.authc.saml.SamlRedirect; import org.elasticsearch.xpack.security.authc.saml.SamlUtils; import org.opensaml.saml.saml2.core.LogoutResponse; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -122,7 +123,7 @@ public final class TransportSamlInvalidateSessionAction })), listener::onFailure)); } - private List> filterTokens(List> tokens, Map requiredMetadata) { + private List> filterTokens(Collection> tokens, Map requiredMetadata) { return tokens.stream() .filter(tup -> { Map actualMetadata = tup.v1().getMetadata(); diff --git a/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index 5ebd7933926..49f9b9c01a9 100644 --- a/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.update.UpdateRequest; @@ -62,8 +63,10 @@ import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.security.ScrollHelper; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp; import org.elasticsearch.xpack.core.security.authc.TokenMetaData; @@ -97,6 +100,7 @@ import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -106,8 +110,7 @@ import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; +import java.util.function.Supplier; import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException; import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK; @@ -815,7 +818,7 @@ public final class TokenService extends AbstractComponent { * Find all stored refresh and access tokens that have not been invalidated or expired, and were issued against * the specified realm. */ - public void findActiveTokensForRealm(String realmName, ActionListener>> listener) { + public void findActiveTokensForRealm(String realmName, ActionListener>> listener) { ensureEnabled(); if (Strings.isNullOrEmpty(realmName)) { @@ -835,32 +838,30 @@ public final class TokenService extends AbstractComponent { .should(QueryBuilders.termQuery("refresh_token.invalidated", false)) ); - SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME) + final SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME) + .setScroll(TimeValue.timeValueSeconds(10L)) .setQuery(boolQuery) .setVersion(false) + .setSize(1000) + .setFetchSource(true) .request(); + final Supplier supplier = client.threadPool().getThreadContext().newRestorableContext(false); lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, - ActionListener.wrap(searchResponse -> { - if (searchResponse.isTimedOut()) { - listener.onFailure(new ElasticsearchSecurityException("Failed to find user tokens")); - } else { - listener.onResponse(parseDocuments(searchResponse)); - } - }, listener::onFailure), - client::search)); + ScrollHelper.fetchAllByEntity(client, request, new ContextPreservingActionListener<>(supplier, listener), this::parseHit)); } - private List> parseDocuments(SearchResponse searchResponse) { - return StreamSupport.stream(searchResponse.getHits().spliterator(), false).map(hit -> { - final Map source = hit.getSourceAsMap(); - try { - return parseTokensFromDocument(source); - } catch (IOException e) { - throw invalidGrantException("cannot read token from document"); - } - }).collect(Collectors.toList()); + private Tuple parseHit(SearchHit hit) { + final Map source = hit.getSourceAsMap(); + if (source == null) { + throw new IllegalStateException("token document did not have source but source should have been fetched"); + } + + try { + return parseTokensFromDocument(source); + } catch (IOException e) { + throw invalidGrantException("cannot read token from document"); + } } /** diff --git a/plugin/security/src/test/java/org/elasticsearch/xpack/security/ScrollHelperIntegTests.java b/plugin/security/src/test/java/org/elasticsearch/xpack/security/ScrollHelperIntegTests.java index c79a369cd1b..7ab26b0c33f 100644 --- a/plugin/security/src/test/java/org/elasticsearch/xpack/security/ScrollHelperIntegTests.java +++ b/plugin/security/src/test/java/org/elasticsearch/xpack/security/ScrollHelperIntegTests.java @@ -81,8 +81,8 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase { SearchRequest request = new SearchRequest(); String scrollId = randomAlphaOfLength(5); - SearchHit[] hits = new SearchHit[] {new SearchHit(1)}; - InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 1, 1), null, null, null, false, false, 1); + SearchHit[] hits = new SearchHit[] {new SearchHit(1), new SearchHit(2)}; + InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 3, 1), null, null, null, false, false, 1); SearchResponse response = new SearchResponse(internalResponse, scrollId, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY); @@ -112,7 +112,7 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase { }, Function.identity()); assertNotNull("onFailure wasn't called", failure.get()); - assertEquals("scrolling returned more hits [2] than expected [1] so bailing out to prevent unbounded memory consumption.", + assertEquals("scrolling returned more hits [4] than expected [3] so bailing out to prevent unbounded memory consumption.", failure.get().getMessage()); } } diff --git a/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java b/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java index 2bd300a694b..e8f7eaf877b 100644 --- a/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java +++ b/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java @@ -14,6 +14,9 @@ import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.search.ClearScrollAction; +import org.elasticsearch.action.search.ClearScrollRequest; +import org.elasticsearch.action.search.ClearScrollResponse; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; @@ -144,6 +147,12 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase { new SearchResponseSections(new SearchHits(hits, hits.length, 0f), null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null); listener.onResponse((Response) response); + } else if (ClearScrollAction.NAME.equals(action.name())) { + assertThat(request, instanceOf(ClearScrollRequest.class)); + ClearScrollRequest scrollRequest = (ClearScrollRequest) request; + assertEquals("_scrollId1", scrollRequest.getScrollIds().get(0)); + ClearScrollResponse response = new ClearScrollResponse(true, 1); + listener.onResponse((Response) response); } else { super.doExecute(action, request, listener); }