[SAML] Find all tokens for a realm, not just the first 10 (elastic/x-pack-elasticsearch#3689)
This commit changes the token service to use a scroll based approach when finding all tokens by the realm. Without this, we may only find a few tokens and leave some active that need to be invalidated. relates elastic/x-pack-elasticsearch#3688 Original commit: elastic/x-pack-elasticsearch@20e97b6aae
This commit is contained in:
parent
2c46002c00
commit
0baa45d9b3
|
@ -67,6 +67,10 @@ public final class ScrollHelper {
|
||||||
listener.onFailure(new IllegalStateException("scrolling returned more hits [" + results.size()
|
listener.onFailure(new IllegalStateException("scrolling returned more hits [" + results.size()
|
||||||
+ "] than expected [" + resp.getHits().getTotalHits() + "] so bailing out to prevent unbounded "
|
+ "] than expected [" + resp.getHits().getTotalHits() + "] so bailing out to prevent unbounded "
|
||||||
+ "memory consumption."));
|
+ "memory consumption."));
|
||||||
|
} else if (results.size() == resp.getHits().getTotalHits()) {
|
||||||
|
clearScroll.accept(resp);
|
||||||
|
// Finally, return the list of the entity
|
||||||
|
listener.onResponse(Collections.unmodifiableList(results));
|
||||||
} else {
|
} else {
|
||||||
SearchScrollRequest scrollRequest = new SearchScrollRequest(resp.getScrollId());
|
SearchScrollRequest scrollRequest = new SearchScrollRequest(resp.getScrollId());
|
||||||
scrollRequest.scroll(request.scroll().keepAlive());
|
scrollRequest.scroll(request.scroll().keepAlive());
|
||||||
|
@ -74,7 +78,7 @@ public final class ScrollHelper {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
clearScroll.accept(resp);
|
clearScroll.accept(resp);
|
||||||
// Finally, return the list of users
|
// Finally, return the list of the entity
|
||||||
listener.onResponse(Collections.unmodifiableList(results));
|
listener.onResponse(Collections.unmodifiableList(results));
|
||||||
}
|
}
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.xpack.security.authc.saml.SamlRedirect;
|
||||||
import org.elasticsearch.xpack.security.authc.saml.SamlUtils;
|
import org.elasticsearch.xpack.security.authc.saml.SamlUtils;
|
||||||
import org.opensaml.saml.saml2.core.LogoutResponse;
|
import org.opensaml.saml.saml2.core.LogoutResponse;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -122,7 +123,7 @@ public final class TransportSamlInvalidateSessionAction
|
||||||
})), listener::onFailure));
|
})), listener::onFailure));
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Tuple<UserToken, String>> filterTokens(List<Tuple<UserToken, String>> tokens, Map<String, Object> requiredMetadata) {
|
private List<Tuple<UserToken, String>> filterTokens(Collection<Tuple<UserToken, String>> tokens, Map<String, Object> requiredMetadata) {
|
||||||
return tokens.stream()
|
return tokens.stream()
|
||||||
.filter(tup -> {
|
.filter(tup -> {
|
||||||
Map<String, Object> actualMetadata = tup.v1().getMetadata();
|
Map<String, Object> actualMetadata = tup.v1().getMetadata();
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.index.IndexResponse;
|
import org.elasticsearch.action.index.IndexResponse;
|
||||||
import org.elasticsearch.action.search.SearchRequest;
|
import org.elasticsearch.action.search.SearchRequest;
|
||||||
import org.elasticsearch.action.search.SearchResponse;
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
|
import org.elasticsearch.action.support.ContextPreservingActionListener;
|
||||||
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
|
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
|
||||||
import org.elasticsearch.action.support.master.AcknowledgedRequest;
|
import org.elasticsearch.action.support.master.AcknowledgedRequest;
|
||||||
import org.elasticsearch.action.update.UpdateRequest;
|
import org.elasticsearch.action.update.UpdateRequest;
|
||||||
|
@ -62,8 +63,10 @@ import org.elasticsearch.index.engine.VersionConflictEngineException;
|
||||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||||
import org.elasticsearch.index.query.QueryBuilders;
|
import org.elasticsearch.index.query.QueryBuilders;
|
||||||
import org.elasticsearch.rest.RestStatus;
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.search.SearchHit;
|
||||||
import org.elasticsearch.xpack.core.XPackSettings;
|
import org.elasticsearch.xpack.core.XPackSettings;
|
||||||
import org.elasticsearch.xpack.core.XPackField;
|
import org.elasticsearch.xpack.core.XPackField;
|
||||||
|
import org.elasticsearch.xpack.core.security.ScrollHelper;
|
||||||
import org.elasticsearch.xpack.core.security.authc.Authentication;
|
import org.elasticsearch.xpack.core.security.authc.Authentication;
|
||||||
import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp;
|
import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp;
|
||||||
import org.elasticsearch.xpack.core.security.authc.TokenMetaData;
|
import org.elasticsearch.xpack.core.security.authc.TokenMetaData;
|
||||||
|
@ -97,6 +100,7 @@ import java.time.temporal.ChronoUnit;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
@ -106,8 +110,7 @@ import java.util.Optional;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
import java.util.stream.Collectors;
|
import java.util.function.Supplier;
|
||||||
import java.util.stream.StreamSupport;
|
|
||||||
|
|
||||||
import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException;
|
import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException;
|
||||||
import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK;
|
import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK;
|
||||||
|
@ -815,7 +818,7 @@ public final class TokenService extends AbstractComponent {
|
||||||
* Find all stored refresh and access tokens that have not been invalidated or expired, and were issued against
|
* Find all stored refresh and access tokens that have not been invalidated or expired, and were issued against
|
||||||
* the specified realm.
|
* the specified realm.
|
||||||
*/
|
*/
|
||||||
public void findActiveTokensForRealm(String realmName, ActionListener<List<Tuple<UserToken, String>>> listener) {
|
public void findActiveTokensForRealm(String realmName, ActionListener<Collection<Tuple<UserToken, String>>> listener) {
|
||||||
ensureEnabled();
|
ensureEnabled();
|
||||||
|
|
||||||
if (Strings.isNullOrEmpty(realmName)) {
|
if (Strings.isNullOrEmpty(realmName)) {
|
||||||
|
@ -835,32 +838,30 @@ public final class TokenService extends AbstractComponent {
|
||||||
.should(QueryBuilders.termQuery("refresh_token.invalidated", false))
|
.should(QueryBuilders.termQuery("refresh_token.invalidated", false))
|
||||||
);
|
);
|
||||||
|
|
||||||
SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME)
|
final SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME)
|
||||||
|
.setScroll(TimeValue.timeValueSeconds(10L))
|
||||||
.setQuery(boolQuery)
|
.setQuery(boolQuery)
|
||||||
.setVersion(false)
|
.setVersion(false)
|
||||||
|
.setSize(1000)
|
||||||
|
.setFetchSource(true)
|
||||||
.request();
|
.request();
|
||||||
|
|
||||||
|
final Supplier<ThreadContext.StoredContext> supplier = client.threadPool().getThreadContext().newRestorableContext(false);
|
||||||
lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () ->
|
lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () ->
|
||||||
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request,
|
ScrollHelper.fetchAllByEntity(client, request, new ContextPreservingActionListener<>(supplier, listener), this::parseHit));
|
||||||
ActionListener.<SearchResponse>wrap(searchResponse -> {
|
|
||||||
if (searchResponse.isTimedOut()) {
|
|
||||||
listener.onFailure(new ElasticsearchSecurityException("Failed to find user tokens"));
|
|
||||||
} else {
|
|
||||||
listener.onResponse(parseDocuments(searchResponse));
|
|
||||||
}
|
|
||||||
}, listener::onFailure),
|
|
||||||
client::search));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Tuple<UserToken, String>> parseDocuments(SearchResponse searchResponse) {
|
private Tuple<UserToken, String> parseHit(SearchHit hit) {
|
||||||
return StreamSupport.stream(searchResponse.getHits().spliterator(), false).map(hit -> {
|
final Map<String, Object> source = hit.getSourceAsMap();
|
||||||
final Map<String, Object> source = hit.getSourceAsMap();
|
if (source == null) {
|
||||||
try {
|
throw new IllegalStateException("token document did not have source but source should have been fetched");
|
||||||
return parseTokensFromDocument(source);
|
}
|
||||||
} catch (IOException e) {
|
|
||||||
throw invalidGrantException("cannot read token from document");
|
try {
|
||||||
}
|
return parseTokensFromDocument(source);
|
||||||
}).collect(Collectors.toList());
|
} catch (IOException e) {
|
||||||
|
throw invalidGrantException("cannot read token from document");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -81,8 +81,8 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase {
|
||||||
SearchRequest request = new SearchRequest();
|
SearchRequest request = new SearchRequest();
|
||||||
|
|
||||||
String scrollId = randomAlphaOfLength(5);
|
String scrollId = randomAlphaOfLength(5);
|
||||||
SearchHit[] hits = new SearchHit[] {new SearchHit(1)};
|
SearchHit[] hits = new SearchHit[] {new SearchHit(1), new SearchHit(2)};
|
||||||
InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 1, 1), null, null, null, false, false, 1);
|
InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 3, 1), null, null, null, false, false, 1);
|
||||||
SearchResponse response = new SearchResponse(internalResponse, scrollId, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY,
|
SearchResponse response = new SearchResponse(internalResponse, scrollId, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY,
|
||||||
SearchResponse.Clusters.EMPTY);
|
SearchResponse.Clusters.EMPTY);
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ public class ScrollHelperIntegTests extends ESSingleNodeTestCase {
|
||||||
}, Function.identity());
|
}, Function.identity());
|
||||||
|
|
||||||
assertNotNull("onFailure wasn't called", failure.get());
|
assertNotNull("onFailure wasn't called", failure.get());
|
||||||
assertEquals("scrolling returned more hits [2] than expected [1] so bailing out to prevent unbounded memory consumption.",
|
assertEquals("scrolling returned more hits [4] than expected [3] so bailing out to prevent unbounded memory consumption.",
|
||||||
failure.get().getMessage());
|
failure.get().getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,9 @@ import org.elasticsearch.action.ActionResponse;
|
||||||
import org.elasticsearch.action.index.IndexAction;
|
import org.elasticsearch.action.index.IndexAction;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.index.IndexResponse;
|
import org.elasticsearch.action.index.IndexResponse;
|
||||||
|
import org.elasticsearch.action.search.ClearScrollAction;
|
||||||
|
import org.elasticsearch.action.search.ClearScrollRequest;
|
||||||
|
import org.elasticsearch.action.search.ClearScrollResponse;
|
||||||
import org.elasticsearch.action.search.SearchAction;
|
import org.elasticsearch.action.search.SearchAction;
|
||||||
import org.elasticsearch.action.search.SearchRequest;
|
import org.elasticsearch.action.search.SearchRequest;
|
||||||
import org.elasticsearch.action.search.SearchResponse;
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
|
@ -144,6 +147,12 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
|
||||||
new SearchResponseSections(new SearchHits(hits, hits.length, 0f),
|
new SearchResponseSections(new SearchHits(hits, hits.length, 0f),
|
||||||
null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null);
|
null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null);
|
||||||
listener.onResponse((Response) response);
|
listener.onResponse((Response) response);
|
||||||
|
} else if (ClearScrollAction.NAME.equals(action.name())) {
|
||||||
|
assertThat(request, instanceOf(ClearScrollRequest.class));
|
||||||
|
ClearScrollRequest scrollRequest = (ClearScrollRequest) request;
|
||||||
|
assertEquals("_scrollId1", scrollRequest.getScrollIds().get(0));
|
||||||
|
ClearScrollResponse response = new ClearScrollResponse(true, 1);
|
||||||
|
listener.onResponse((Response) response);
|
||||||
} else {
|
} else {
|
||||||
super.doExecute(action, request, listener);
|
super.doExecute(action, request, listener);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue