[SAML] Find all tokens for a realm, not just the first 10 (elastic/x-pack-elasticsearch#3689)
This commit changes the token service to use a scroll based approach when finding all tokens by the realm. Without this, we may only find a few tokens and leave some active that need to be invalidated. relates elastic/x-pack-elasticsearch#3688 Original commit: elastic/x-pack-elasticsearch@20e97b6aae
This commit is contained in:
parent
2c46002c00
commit
0baa45d9b3
|
@ -67,6 +67,10 @@ public final class ScrollHelper {
|
|||
listener.onFailure(new IllegalStateException("scrolling returned more hits [" + results.size()
|
||||
+ "] than expected [" + resp.getHits().getTotalHits() + "] so bailing out to prevent unbounded "
|
||||
+ "memory consumption."));
|
||||
} else if (results.size() == resp.getHits().getTotalHits()) {
|
||||
clearScroll.accept(resp);
|
||||
// Finally, return the list of the entity
|
||||
listener.onResponse(Collections.unmodifiableList(results));
|
||||
} else {
|
||||
SearchScrollRequest scrollRequest = new SearchScrollRequest(resp.getScrollId());
|
||||
scrollRequest.scroll(request.scroll().keepAlive());
|
||||
|
@ -74,7 +78,7 @@ public final class ScrollHelper {
|
|||
}
|
||||
} else {
|
||||
clearScroll.accept(resp);
|
||||
// Finally, return the list of users
|
||||
// Finally, return the list of the entity
|
||||
listener.onResponse(Collections.unmodifiableList(results));
|
||||
}
|
||||
} catch (Exception e){
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.xpack.security.authc.saml.SamlRedirect;
|
|||
import org.elasticsearch.xpack.security.authc.saml.SamlUtils;
|
||||
import org.opensaml.saml.saml2.core.LogoutResponse;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -122,7 +123,7 @@ public final class TransportSamlInvalidateSessionAction
|
|||
})), listener::onFailure));
|
||||
}
|
||||
|
||||
private List<Tuple<UserToken, String>> filterTokens(List<Tuple<UserToken, String>> tokens, Map<String, Object> requiredMetadata) {
|
||||
private List<Tuple<UserToken, String>> filterTokens(Collection<Tuple<UserToken, String>> tokens, Map<String, Object> requiredMetadata) {
|
||||
return tokens.stream()
|
||||
.filter(tup -> {
|
||||
Map<String, Object> actualMetadata = tup.v1().getMetadata();
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.elasticsearch.action.index.IndexRequest;
|
|||
import org.elasticsearch.action.index.IndexResponse;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.ContextPreservingActionListener;
|
||||
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedRequest;
|
||||
import org.elasticsearch.action.update.UpdateRequest;
|
||||
|
@ -62,8 +63,10 @@ import org.elasticsearch.index.engine.VersionConflictEngineException;
|
|||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.XPackSettings;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.security.ScrollHelper;
|
||||
import org.elasticsearch.xpack.core.security.authc.Authentication;
|
||||
import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp;
|
||||
import org.elasticsearch.xpack.core.security.authc.TokenMetaData;
|
||||
|
@ -97,6 +100,7 @@ import java.time.temporal.ChronoUnit;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
|
@ -106,8 +110,7 @@ import java.util.Optional;
|
|||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.StreamSupport;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException;
|
||||
import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK;
|
||||
|
@ -815,7 +818,7 @@ public final class TokenService extends AbstractComponent {
|
|||
* Find all stored refresh and access tokens that have not been invalidated or expired, and were issued against
|
||||
* the specified realm.
|
||||
*/
|
||||
public void findActiveTokensForRealm(String realmName, ActionListener<List<Tuple<UserToken, String>>> listener) {
|
||||
public void findActiveTokensForRealm(String realmName, ActionListener<Collection<Tuple<UserToken, String>>> listener) {
|
||||
ensureEnabled();
|
||||
|
||||
if (Strings.isNullOrEmpty(realmName)) {
|
||||
|
@ -835,32 +838,30 @@ public final class TokenService extends AbstractComponent {
|
|||
.should(QueryBuilders.termQuery("refresh_token.invalidated", false))
|
||||
);
|
||||
|
||||
SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME)
|
||||
final SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME)
|
||||
.setScroll(TimeValue.timeValueSeconds(10L))
|
||||
.setQuery(boolQuery)
|
||||
.setVersion(false)
|
||||
.setSize(1000)
|
||||
.setFetchSource(true)
|
||||
.request();
|
||||
|
||||
final Supplier<ThreadContext.StoredContext> supplier = client.threadPool().getThreadContext().newRestorableContext(false);
|
||||
lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () ->
|
||||
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request,
|
||||
ActionListener.<SearchResponse>wrap(searchResponse -> {
|
||||
if (searchResponse.isTimedOut()) {
|
||||
listener.onFailure(new ElasticsearchSecurityException("Failed to find user tokens"));
|
||||
} else {
|
||||
listener.onResponse(parseDocuments(searchResponse));
|
||||
}
|
||||
}, listener::onFailure),
|
||||
client::search));
|
||||
ScrollHelper.fetchAllByEntity(client, request, new ContextPreservingActionListener<>(supplier, listener), this::parseHit));
|
||||
}
|
||||
|
||||
private List<Tuple<UserToken, String>> parseDocuments(SearchResponse searchResponse) {
|
||||
return StreamSupport.stream(searchResponse.getHits().spliterator(), false).map(hit -> {
|
||||
final Map<String, Object> source = hit.getSourceAsMap();
|
||||
try {
|
||||
return parseTokensFromDocument(source);
|
||||
} catch (IOException e) {
|
||||
throw invalidGrantException("cannot read token from document");
|
||||
}
|
||||
}).collect(Collectors.toList());
|
||||
private Tuple<UserToken, String> parseHit(SearchHit hit) {
|
||||
final Map<String, Object> source = hit.getSourceAsMap();
|
||||
if (source == null) {
|
||||
throw new IllegalStateException("token document did not have source but source should have been fetched");
|
||||
}
|
||||
|
||||
try {
|
||||
return parseTokensFromDocument(source);
|
||||
} catch (IOException e) {
|
||||
throw invalidGrantException("cannot read token from document");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -81,8 +81,8 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase {
|
|||
SearchRequest request = new SearchRequest();
|
||||
|
||||
String scrollId = randomAlphaOfLength(5);
|
||||
SearchHit[] hits = new SearchHit[] {new SearchHit(1)};
|
||||
InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 1, 1), null, null, null, false, false, 1);
|
||||
SearchHit[] hits = new SearchHit[] {new SearchHit(1), new SearchHit(2)};
|
||||
InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 3, 1), null, null, null, false, false, 1);
|
||||
SearchResponse response = new SearchResponse(internalResponse, scrollId, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY,
|
||||
SearchResponse.Clusters.EMPTY);
|
||||
|
||||
|
@ -112,7 +112,7 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase {
|
|||
}, Function.identity());
|
||||
|
||||
assertNotNull("onFailure wasn't called", failure.get());
|
||||
assertEquals("scrolling returned more hits [2] than expected [1] so bailing out to prevent unbounded memory consumption.",
|
||||
assertEquals("scrolling returned more hits [4] than expected [3] so bailing out to prevent unbounded memory consumption.",
|
||||
failure.get().getMessage());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,9 @@ import org.elasticsearch.action.ActionResponse;
|
|||
import org.elasticsearch.action.index.IndexAction;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.index.IndexResponse;
|
||||
import org.elasticsearch.action.search.ClearScrollAction;
|
||||
import org.elasticsearch.action.search.ClearScrollRequest;
|
||||
import org.elasticsearch.action.search.ClearScrollResponse;
|
||||
import org.elasticsearch.action.search.SearchAction;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
|
@ -144,6 +147,12 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
|
|||
new SearchResponseSections(new SearchHits(hits, hits.length, 0f),
|
||||
null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null);
|
||||
listener.onResponse((Response) response);
|
||||
} else if (ClearScrollAction.NAME.equals(action.name())) {
|
||||
assertThat(request, instanceOf(ClearScrollRequest.class));
|
||||
ClearScrollRequest scrollRequest = (ClearScrollRequest) request;
|
||||
assertEquals("_scrollId1", scrollRequest.getScrollIds().get(0));
|
||||
ClearScrollResponse response = new ClearScrollResponse(true, 1);
|
||||
listener.onResponse((Response) response);
|
||||
} else {
|
||||
super.doExecute(action, request, listener);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue