diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/InternalClient.java b/plugin/src/main/java/org/elasticsearch/xpack/security/InternalClient.java
index db53af79f94..d45c85bacb1 100644
--- a/plugin/src/main/java/org/elasticsearch/xpack/security/InternalClient.java
+++ b/plugin/src/main/java/org/elasticsearch/xpack/security/InternalClient.java
@@ -100,7 +100,8 @@ public class InternalClient extends FilterClient {
         }
         final Consumer<SearchResponse> clearScroll = (response) -> {
             if (response != null && response.getScrollId() != null) {
-                ClearScrollRequest clearScrollRequest = client.prepareClearScroll().addScrollId(response.getScrollId()).request();
+                ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
+                clearScrollRequest.addScrollId(response.getScrollId());
                 client.clearScroll(clearScrollRequest, ActionListener.wrap((r) -> {}, (e) -> {}));
             }
         };
@@ -120,10 +121,17 @@ public class InternalClient extends FilterClient {
                                results.add(oneResult);
                            }
                        }
-                       SearchScrollRequest scrollRequest = client.prepareSearchScroll(resp.getScrollId())
-                               .setScroll(request.scroll().keepAlive()).request();
-                       client.searchScroll(scrollRequest, this);
+                       if (results.size() > resp.getHits().getTotalHits()) {
+                           clearScroll.accept(lastResponse);
+                           listener.onFailure(new IllegalStateException("scrolling returned more hits [" + results.size() +
+                                   "] than expected [" + resp.getHits().getTotalHits() + "] so bailing out to prevent unbounded " +
+                                   "memory consumption."));
+                       } else {
+                           SearchScrollRequest scrollRequest = new SearchScrollRequest(resp.getScrollId());
+                           scrollRequest.scroll(request.scroll().keepAlive());
+                           client.searchScroll(scrollRequest, this);
+                       }
                    } else {
                        clearScroll.accept(resp);
                        // Finally, return the list of users
diff --git a/plugin/src/test/java/org/elasticsearch/xpack/security/InternalClientIntegTests.java b/plugin/src/test/java/org/elasticsearch/xpack/security/InternalClientIntegTests.java
index a1e08e5a6f2..7e87b39279b 100644
--- a/plugin/src/test/java/org/elasticsearch/xpack/security/InternalClientIntegTests.java
+++ b/plugin/src/test/java/org/elasticsearch/xpack/security/InternalClientIntegTests.java
@@ -6,17 +6,32 @@
 package org.elasticsearch.xpack.security;
 
 import org.apache.lucene.util.CollectionUtil;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.test.ESSingleNodeTestCase;
+import org.mockito.stubbing.Answer;
 
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+
+import static org.mockito.Matchers.anyObject;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+
 public class InternalClientIntegTests extends ESSingleNodeTestCase {
@@ -45,4 +60,50 @@ public class InternalClientIntegTests extends ESSingleNodeTestCase {
             assertEquals(list.get(i).intValue(), i);
         }
     }
+
+    /**
+     * Tests that
+     * {@link
+     * InternalClient#fetchAllByEntity(Client, SearchRequest, org.elasticsearch.action.ActionListener, java.util.function.Function)}
+     * defends against scrolls broken in such a way that the remote Elasticsearch returns infinite results. While Elasticsearch
+     * shouldn't do this, it has in the past, and it is very bad when it does. It takes out the whole node. So
+     * this makes sure we defend against it properly.
+     */
+    public void testFetchAllByEntityWithBrokenScroll() {
+        Client client = mock(Client.class);
+        SearchRequest request = new SearchRequest();
+
+        String scrollId = randomAlphaOfLength(5);
+        SearchHit[] hits = new SearchHit[] {new SearchHit(1)};
+        InternalSearchResponse internalResponse = new InternalSearchResponse(new SearchHits(hits, 1, 1), null, null, null, false, false, 1);
+        SearchResponse response = new SearchResponse(internalResponse, scrollId, 1, 1, 0, ShardSearchFailure.EMPTY_ARRAY);
+
+        Answer<Void> returnResponse = invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocation.getArguments()[1];
+            listener.onResponse(response);
+            return null;
+        };
+        doAnswer(returnResponse).when(client).search(eq(request), anyObject());
+        /* The line below simulates the evil cluster. A working cluster would return
+         * a response with 0 hits. Our simulated broken cluster returns the same
+         * response over and over again. */
+        doAnswer(returnResponse).when(client).searchScroll(anyObject(), anyObject());
+
+        AtomicReference<Exception> failure = new AtomicReference<>();
+        InternalClient.fetchAllByEntity(client, request, new ActionListener<Collection<SearchHit>>() {
+            @Override
+            public void onResponse(Collection<SearchHit> response) {
+                fail("This shouldn't succeed.");
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                failure.set(e);
+            }
+        }, Function.identity());
+
+        assertNotNull("onFailure wasn't called", failure.get());
+        assertEquals("scrolling returned more hits [2] than expected [1] so bailing out to prevent unbounded memory consumption.",
+                failure.get().getMessage());
+    }
 }
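
For reference, the shape of the defense is easy to see outside of Elasticsearch. The sketch below is not part of the patch and uses hypothetical names (Page, fetchAll, fetchNextPage) in place of SearchResponse and the searchScroll round trip: results are accumulated page by page, and the loop fails fast as soon as the accumulated count exceeds the total the first response advertised. That is also what the mocked test exercises: searchScroll keeps replaying the same one-hit page, so the second pass pushes the count to 2 against an advertised total of 1 and produces the asserted failure message.

    // Standalone sketch, not part of the patch: the same guard fetchAllByEntity now applies,
    // reduced to plain Java. `Page` and `fetchNextPage` are hypothetical stand-ins for
    // SearchResponse and the searchScroll round trip.
    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.Supplier;

    final class ScrollGuardSketch {

        /** One page of scroll results: the hits it carried plus the advertised total. */
        static final class Page {
            final List<String> hits;
            final long totalHits;

            Page(List<String> hits, long totalHits) {
                this.hits = hits;
                this.totalHits = totalHits;
            }
        }

        /** Accumulates pages until an empty one arrives, bailing out if the scroll misbehaves. */
        static List<String> fetchAll(Supplier<Page> fetchNextPage) {
            List<String> results = new ArrayList<>();
            while (true) {
                Page page = fetchNextPage.get();
                if (page.hits.isEmpty()) {
                    return results; // a well-behaved scroll eventually returns an empty page
                }
                results.addAll(page.hits);
                // The guard from the diff: a broken scroll that keeps replaying the same hits
                // would otherwise grow `results` without bound and take out the node.
                if (results.size() > page.totalHits) {
                    throw new IllegalStateException("scrolling returned more hits [" + results.size()
                            + "] than expected [" + page.totalHits + "]");
                }
            }
        }
    }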