[SAML] Find all tokens for a realm, not just the first 10 (elastic/x-pack-elasticsearch#3689)

This commit changes the token service to use a scroll based approach when finding all tokens by
the realm. Without this, we may only find a few tokens and leave some active that need to be
invalidated.

relates elastic/x-pack-elasticsearch#3688

Original commit: elastic/x-pack-elasticsearch@20e97b6aae
This commit is contained in:
Jay Modi 2018-01-24 11:07:51 -07:00 committed by GitHub
parent 2c46002c00
commit 0baa45d9b3
5 changed files with 42 additions and 27 deletions

View File

@ -67,6 +67,10 @@ public final class ScrollHelper {
listener.onFailure(new IllegalStateException("scrolling returned more hits [" + results.size()
+ "] than expected [" + resp.getHits().getTotalHits() + "] so bailing out to prevent unbounded "
+ "memory consumption."));
} else if (results.size() == resp.getHits().getTotalHits()) {
clearScroll.accept(resp);
// Finally, return the list of the entity
listener.onResponse(Collections.unmodifiableList(results));
} else {
SearchScrollRequest scrollRequest = new SearchScrollRequest(resp.getScrollId());
scrollRequest.scroll(request.scroll().keepAlive());
@ -74,7 +78,7 @@ public final class ScrollHelper {
}
} else {
clearScroll.accept(resp);
// Finally, return the list of users
// Finally, return the list of the entity
listener.onResponse(Collections.unmodifiableList(results));
}
} catch (Exception e){

View File

@ -29,6 +29,7 @@ import org.elasticsearch.xpack.security.authc.saml.SamlRedirect;
import org.elasticsearch.xpack.security.authc.saml.SamlUtils;
import org.opensaml.saml.saml2.core.LogoutResponse;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -122,7 +123,7 @@ public final class TransportSamlInvalidateSessionAction
})), listener::onFailure));
}
private List<Tuple<UserToken, String>> filterTokens(List<Tuple<UserToken, String>> tokens, Map<String, Object> requiredMetadata) {
private List<Tuple<UserToken, String>> filterTokens(Collection<Tuple<UserToken, String>> tokens, Map<String, Object> requiredMetadata) {
return tokens.stream()
.filter(tup -> {
Map<String, Object> actualMetadata = tup.v1().getMetadata();

View File

@ -26,6 +26,7 @@ import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.update.UpdateRequest;
@ -62,8 +63,10 @@ import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.security.ScrollHelper;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp;
import org.elasticsearch.xpack.core.security.authc.TokenMetaData;
@ -97,6 +100,7 @@ import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
@ -106,8 +110,7 @@ import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import java.util.function.Supplier;
import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException;
import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK;
@ -815,7 +818,7 @@ public final class TokenService extends AbstractComponent {
* Find all stored refresh and access tokens that have not been invalidated or expired, and were issued against
* the specified realm.
*/
public void findActiveTokensForRealm(String realmName, ActionListener<List<Tuple<UserToken, String>>> listener) {
public void findActiveTokensForRealm(String realmName, ActionListener<Collection<Tuple<UserToken, String>>> listener) {
ensureEnabled();
if (Strings.isNullOrEmpty(realmName)) {
@ -835,32 +838,30 @@ public final class TokenService extends AbstractComponent {
.should(QueryBuilders.termQuery("refresh_token.invalidated", false))
);
SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME)
final SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME)
.setScroll(TimeValue.timeValueSeconds(10L))
.setQuery(boolQuery)
.setVersion(false)
.setSize(1000)
.setFetchSource(true)
.request();
final Supplier<ThreadContext.StoredContext> supplier = client.threadPool().getThreadContext().newRestorableContext(false);
lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () ->
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request,
ActionListener.<SearchResponse>wrap(searchResponse -> {
if (searchResponse.isTimedOut()) {
listener.onFailure(new ElasticsearchSecurityException("Failed to find user tokens"));
} else {
listener.onResponse(parseDocuments(searchResponse));
}
}, listener::onFailure),
client::search));
ScrollHelper.fetchAllByEntity(client, request, new ContextPreservingActionListener<>(supplier, listener), this::parseHit));
}
private List<Tuple<UserToken, String>> parseDocuments(SearchResponse searchResponse) {
return StreamSupport.stream(searchResponse.getHits().spliterator(), false).map(hit -> {
final Map<String, Object> source = hit.getSourceAsMap();
try {
return parseTokensFromDocument(source);
} catch (IOException e) {
throw invalidGrantException("cannot read token from document");
}
}).collect(Collectors.toList());
private Tuple<UserToken, String> parseHit(SearchHit hit) {
final Map<String, Object> source = hit.getSourceAsMap();
if (source == null) {
throw new IllegalStateException("token document did not have source but source should have been fetched");
}
try {
return parseTokensFromDocument(source);
} catch (IOException e) {
throw invalidGrantException("cannot read token from document");
}
}
/**

View File

@ -81,8 +81,8 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase {
SearchRequest request = new SearchRequest();
String scrollId = randomAlphaOfLength(5);
SearchHit[] hits = new SearchHit[] {new SearchHit(1)};
InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 1, 1), null, null, null, false, false, 1);
SearchHit[] hits = new SearchHit[] {new SearchHit(1), new SearchHit(2)};
InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 3, 1), null, null, null, false, false, 1);
SearchResponse response = new SearchResponse(internalResponse, scrollId, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY);
@ -112,7 +112,7 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase {
}, Function.identity());
assertNotNull("onFailure wasn't called", failure.get());
assertEquals("scrolling returned more hits [2] than expected [1] so bailing out to prevent unbounded memory consumption.",
assertEquals("scrolling returned more hits [4] than expected [3] so bailing out to prevent unbounded memory consumption.",
failure.get().getMessage());
}
}

View File

@ -14,6 +14,9 @@ import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.ClearScrollAction;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.ClearScrollResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
@ -144,6 +147,12 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
new SearchResponseSections(new SearchHits(hits, hits.length, 0f),
null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null);
listener.onResponse((Response) response);
} else if (ClearScrollAction.NAME.equals(action.name())) {
assertThat(request, instanceOf(ClearScrollRequest.class));
ClearScrollRequest scrollRequest = (ClearScrollRequest) request;
assertEquals("_scrollId1", scrollRequest.getScrollIds().get(0));
ClearScrollResponse response = new ClearScrollResponse(true, 1);
listener.onResponse((Response) response);
} else {
super.doExecute(action, request, listener);
}