Add searchScroll method to high level REST client (#24938)

This commit is contained in:
Luca Cavanna 2017-06-01 10:56:17 +02:00 committed by GitHub
parent c9aeb34d50
commit 856235fac2
4 changed files with 98 additions and 20 deletions

View File

@ -34,6 +34,7 @@ import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.support.ActiveShardCount;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
@ -63,7 +64,7 @@ import java.util.StringJoiner;
final class Request {
private static final XContentType REQUEST_BODY_CONTENT_TYPE = XContentType.JSON;
static final XContentType REQUEST_BODY_CONTENT_TYPE = XContentType.JSON;
final String method;
final String endpoint;
@ -338,6 +339,11 @@ final class Request {
return new Request(HttpGet.METHOD_NAME, endpoint, params.getParams(), entity);
}
static Request searchScroll(SearchScrollRequest searchScrollRequest) throws IOException {
HttpEntity entity = createEntity(searchScrollRequest, REQUEST_BODY_CONTENT_TYPE);
return new Request("GET", "/_search/scroll", Collections.emptyMap(), entity);
}
private static HttpEntity createEntity(ToXContent toXContent, XContentType xContentType) throws IOException {
BytesRef source = XContentHelper.toXContent(toXContent, xContentType, false).toBytesRef();
return new ByteArrayEntity(source.bytes, source.offset, source.length, ContentType.create(xContentType.mediaType()));
@ -483,7 +489,7 @@ final class Request {
return this;
}
Params withIndicesOptions (IndicesOptions indicesOptions) {
Params withIndicesOptions(IndicesOptions indicesOptions) {
putParam("ignore_unavailable", Boolean.toString(indicesOptions.ignoreUnavailable()));
putParam("allow_no_indices", Boolean.toString(indicesOptions.allowNoIndices()));
String expandWildcards;

View File

@ -38,6 +38,7 @@ import org.elasticsearch.action.main.MainRequest;
import org.elasticsearch.action.main.MainResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.common.CheckedFunction;
@ -325,6 +326,27 @@ public class RestHighLevelClient {
performRequestAsyncAndParseEntity(searchRequest, Request::search, SearchResponse::fromXContent, listener, emptySet(), headers);
}
/**
* Executes a search using the Search Scroll api
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-scroll.html">Search Scroll
* API on elastic.co</a>
*/
public SearchResponse searchScroll(SearchScrollRequest searchScrollRequest, Header... headers) throws IOException {
return performRequestAndParseEntity(searchScrollRequest, Request::searchScroll, SearchResponse::fromXContent, emptySet(), headers);
}
/**
* Asynchronously executes a search using the Search Scroll api
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-scroll.html">Search Scroll
* API on elastic.co</a>
*/
public void searchScrollAsync(SearchScrollRequest searchScrollRequest, ActionListener<SearchResponse> listener, Header... headers) {
performRequestAsyncAndParseEntity(searchScrollRequest, Request::searchScroll, SearchResponse::fromXContent,
listener, emptySet(), headers);
}
private <Req extends ActionRequest, Resp> Resp performRequestAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
@ -354,6 +376,7 @@ public class RestHighLevelClient {
}
throw parseResponseException(e);
}
try {
return responseConverter.apply(response);
} catch(Exception e) {

View File

@ -29,6 +29,7 @@ import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
@ -40,6 +41,7 @@ import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.lucene.uid.Versions;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
@ -714,12 +716,29 @@ public class RequestTests extends ESTestCase {
if (searchSourceBuilder == null) {
assertNull(request.entity);
} else {
BytesReference expectedBytes = XContentHelper.toXContent(searchSourceBuilder, XContentType.JSON, false);
assertEquals(XContentType.JSON.mediaType(), request.entity.getContentType().getValue());
assertEquals(expectedBytes, new BytesArray(EntityUtils.toByteArray(request.entity)));
assertToXContentBody(searchSourceBuilder, request.entity);
}
}
public void testSearchScroll() throws IOException {
SearchScrollRequest searchScrollRequest = new SearchScrollRequest();
searchScrollRequest.scrollId(randomAlphaOfLengthBetween(5, 10));
if (randomBoolean()) {
searchScrollRequest.scroll(randomPositiveTimeValue());
}
Request request = Request.searchScroll(searchScrollRequest);
assertEquals("GET", request.method);
assertEquals("/_search/scroll", request.endpoint);
assertEquals(0, request.params.size());
assertToXContentBody(searchScrollRequest, request.entity);
}
private static void assertToXContentBody(ToXContent expectedBody, HttpEntity actualEntity) throws IOException {
BytesReference expectedBytes = XContentHelper.toXContent(expectedBody, Request.REQUEST_BODY_CONTENT_TYPE, false);
assertEquals(XContentType.JSON.mediaType(), actualEntity.getContentType().getValue());
assertEquals(expectedBytes, new BytesArray(EntityUtils.toByteArray(actualEntity)));
}
public void testParams() {
final int nbParams = randomIntBetween(0, 10);
Request.Params params = Request.Params.builder();

View File

@ -33,6 +33,7 @@ import org.apache.http.entity.StringEntity;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicRequestLine;
import org.apache.http.message.BasicStatusLine;
import org.apache.http.nio.entity.NStringEntity;
import org.elasticsearch.Build;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
@ -41,21 +42,26 @@ import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.main.MainRequest;
import org.elasticsearch.action.main.MainResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchResponseSections;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.cbor.CborXContent;
import org.elasticsearch.common.xcontent.smile.SmileXContent;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.mockito.ArgumentMatcher;
import org.mockito.Matchers;
import org.mockito.internal.matchers.ArrayEquals;
import org.mockito.internal.matchers.VarargMatcher;
@ -68,6 +74,7 @@ import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.client.RestClientTestUtil.randomHeaders;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.mockito.Matchers.anyMapOf;
@ -76,6 +83,8 @@ import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.anyVararg;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isNotNull;
import static org.mockito.Matchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -95,49 +104,70 @@ public class RestHighLevelClientTests extends ESTestCase {
}
public void testPingSuccessful() throws IOException {
Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
Header[] headers = randomHeaders(random(), "Header");
Response response = mock(Response.class);
when(response.getStatusLine()).thenReturn(newStatusLine(RestStatus.OK));
when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
anyObject(), anyVararg())).thenReturn(response);
assertTrue(restHighLevelClient.ping(headers));
verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()),
Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
}
public void testPing404NotFound() throws IOException {
Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
Header[] headers = randomHeaders(random(), "Header");
Response response = mock(Response.class);
when(response.getStatusLine()).thenReturn(newStatusLine(RestStatus.NOT_FOUND));
when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
anyObject(), anyVararg())).thenReturn(response);
assertFalse(restHighLevelClient.ping(headers));
verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()),
Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
}
public void testPingSocketTimeout() throws IOException {
Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
Header[] headers = randomHeaders(random(), "Header");
when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
anyObject(), anyVararg())).thenThrow(new SocketTimeoutException());
expectThrows(SocketTimeoutException.class, () -> restHighLevelClient.ping(headers));
verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()),
Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
}
public void testInfo() throws IOException {
Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
Response response = mock(Response.class);
Header[] headers = randomHeaders(random(), "Header");
MainResponse testInfo = new MainResponse("nodeName", Version.CURRENT, new ClusterName("clusterName"), "clusterUuid",
Build.CURRENT, true);
when(response.getEntity()).thenReturn(
new StringEntity(toXContent(testInfo, XContentType.JSON, false).utf8ToString(), ContentType.APPLICATION_JSON));
when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
anyObject(), anyVararg())).thenReturn(response);
mockResponse(testInfo);
MainResponse receivedInfo = restHighLevelClient.info(headers);
assertEquals(testInfo, receivedInfo);
verify(restClient).performRequest(eq("GET"), eq("/"), eq(Collections.emptyMap()),
Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
}
public void testSearchScroll() throws IOException {
Header[] headers = randomHeaders(random(), "Header");
SearchResponse mockSearchResponse = new SearchResponse(new SearchResponseSections(SearchHits.empty(), InternalAggregations.EMPTY,
null, false, false, null, 1), randomAlphaOfLengthBetween(5, 10), 5, 5, 100, new ShardSearchFailure[0]);
mockResponse(mockSearchResponse);
SearchResponse searchResponse = restHighLevelClient.searchScroll(new SearchScrollRequest(randomAlphaOfLengthBetween(5, 10)),
headers);
assertEquals(mockSearchResponse.getScrollId(), searchResponse.getScrollId());
assertEquals(0, searchResponse.getHits().totalHits);
assertEquals(5, searchResponse.getTotalShards());
assertEquals(5, searchResponse.getSuccessfulShards());
assertEquals(100, searchResponse.getTook().getMillis());
verify(restClient).performRequest(eq("GET"), eq("/_search/scroll"), eq(Collections.emptyMap()),
isNotNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
}
private void mockResponse(ToXContent toXContent) throws IOException {
Response response = mock(Response.class);
ContentType contentType = ContentType.parse(Request.REQUEST_BODY_CONTENT_TYPE.mediaType());
String requestBody = toXContent(toXContent, Request.REQUEST_BODY_CONTENT_TYPE, false).utf8ToString();
when(response.getEntity()).thenReturn(new NStringEntity(requestBody, contentType));
when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
anyObject(), anyVararg())).thenReturn(response);
}
public void testRequestValidation() {