From b8f4c92d41411e17ec45f3c83dc81d1f12d39751 Mon Sep 17 00:00:00 2001 From: Chris Earle Date: Wed, 24 Aug 2016 19:28:32 -0400 Subject: [PATCH] Allow RestClient to send array-based headers This enables the RestClient to send array-based (multi-valued) header values, rather than only sending whatever happened to be the _last_ value of the header. --- .../org/elasticsearch/client/RestClient.java | 13 ++- .../client/RestClientBuilder.java | 4 +- .../client/RestClientIntegTests.java | 62 ++++++--------- .../client/RestClientSingleHostTests.java | 79 +++++++++---------- client/test/build.gradle | 1 + .../client/RestClientTestCase.java | 76 ++++++++++++++++++ 6 files changed, 153 insertions(+), 82 deletions(-) diff --git a/client/rest/src/main/java/org/elasticsearch/client/RestClient.java b/client/rest/src/main/java/org/elasticsearch/client/RestClient.java index 26af479f668..d2301e1e8e7 100644 --- a/client/rest/src/main/java/org/elasticsearch/client/RestClient.java +++ b/client/rest/src/main/java/org/elasticsearch/client/RestClient.java @@ -362,12 +362,17 @@ public class RestClient implements Closeable { private void setHeaders(HttpRequest httpRequest, Header[] requestHeaders) { Objects.requireNonNull(requestHeaders, "request headers must not be null"); - for (Header defaultHeader : defaultHeaders) { - httpRequest.setHeader(defaultHeader); - } + // request headers override default headers, so we don't add default headers if they exist as request headers + final Set requestNames = new HashSet<>(requestHeaders.length); for (Header requestHeader : requestHeaders) { Objects.requireNonNull(requestHeader, "request header must not be null"); - httpRequest.setHeader(requestHeader); + httpRequest.addHeader(requestHeader); + requestNames.add(requestHeader.getName()); + } + for (Header defaultHeader : defaultHeaders) { + if (requestNames.contains(defaultHeader.getName()) == false) { + httpRequest.addHeader(defaultHeader); + } } } diff --git a/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java b/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java index 4f9f379d08e..d342d59ade4 100644 --- a/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java +++ b/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java @@ -71,7 +71,9 @@ public final class RestClientBuilder { } /** - * Sets the default request headers, which will be sent along with each request + * Sets the default request headers, which will be sent along with each request. + *

+ * Request-time headers will always overwrite any default headers. * * @throws NullPointerException if {@code defaultHeaders} or any header is {@code null}. */ diff --git a/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java b/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java index e7d7852de04..9c5c50946d8 100644 --- a/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java +++ b/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java @@ -19,8 +19,6 @@ package org.elasticsearch.client; -import com.carrotsearch.randomizedtesting.generators.RandomInts; -import com.carrotsearch.randomizedtesting.generators.RandomStrings; import com.sun.net.httpserver.Headers; import com.sun.net.httpserver.HttpContext; import com.sun.net.httpserver.HttpExchange; @@ -28,10 +26,8 @@ import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; import org.apache.http.Consts; import org.apache.http.Header; -import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.entity.StringEntity; -import org.apache.http.message.BasicHeader; import org.apache.http.util.EntityUtils; import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.AfterClass; @@ -83,13 +79,8 @@ public class RestClientIntegTests extends RestClientTestCase { for (int statusCode : getAllStatusCodes()) { createStatusCodeContext(httpServer, statusCode); } - int numHeaders = RandomInts.randomIntBetween(getRandom(), 0, 3); - defaultHeaders = new Header[numHeaders]; - for (int i = 0; i < numHeaders; i++) { - String headerName = "Header-default" + (getRandom().nextBoolean() ? i : ""); - String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10); - defaultHeaders[i] = new BasicHeader(headerName, headerValue); - } + int numHeaders = randomIntBetween(0, 5); + defaultHeaders = generateHeaders("Header-default", "Header-array", numHeaders); restClient = RestClient.builder(new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())) .setDefaultHeaders(defaultHeaders).build(); } @@ -148,44 +139,43 @@ public class RestClientIntegTests extends RestClientTestCase { */ public void testHeaders() throws IOException { for (String method : getHttpMethods()) { - Set standardHeaders = new HashSet<>( - Arrays.asList("Connection", "Host", "User-agent", "Date")); + final Set standardHeaders = new HashSet<>(Arrays.asList("Connection", "Host", "User-agent", "Date")); if (method.equals("HEAD") == false) { standardHeaders.add("Content-length"); } - int numHeaders = RandomInts.randomIntBetween(getRandom(), 1, 5); - Map expectedHeaders = new HashMap<>(); - for (Header defaultHeader : defaultHeaders) { - expectedHeaders.put(defaultHeader.getName(), defaultHeader.getValue()); - } - Header[] headers = new Header[numHeaders]; - for (int i = 0; i < numHeaders; i++) { - String headerName = "Header" + (getRandom().nextBoolean() ? i : ""); - String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10); - headers[i] = new BasicHeader(headerName, headerValue); - expectedHeaders.put(headerName, headerValue); - } - int statusCode = randomStatusCode(getRandom()); + final int numHeaders = randomIntBetween(1, 5); + final Header[] headers = generateHeaders("Header", "Header-array", numHeaders); + final Map> expectedHeaders = new HashMap<>(); + + addHeaders(expectedHeaders, defaultHeaders, headers); + + final int statusCode = randomStatusCode(getRandom()); Response esResponse; try { - esResponse = restClient.performRequest(method, "/" + statusCode, Collections.emptyMap(), - (HttpEntity)null, headers); + esResponse = restClient.performRequest(method, "/" + statusCode, Collections.emptyMap(), headers); } catch(ResponseException e) { esResponse = e.getResponse(); } assertThat(esResponse.getStatusLine().getStatusCode(), equalTo(statusCode)); - for (Header responseHeader : esResponse.getHeaders()) { - if (responseHeader.getName().startsWith("Header")) { - String headerValue = expectedHeaders.remove(responseHeader.getName()); - assertNotNull("found response header [" + responseHeader.getName() + "] that wasn't originally sent", headerValue); + for (final Header responseHeader : esResponse.getHeaders()) { + final String name = responseHeader.getName(); + final String value = responseHeader.getValue(); + if (name.startsWith("Header")) { + final List values = expectedHeaders.get(name); + assertNotNull("found response header [" + name + "] that wasn't originally sent: " + value, values); + assertTrue("found incorrect response header [" + name + "]: " + value, values.remove(value)); + + // we've collected them all + if (values.isEmpty()) { + expectedHeaders.remove(name); + } } else { - assertTrue("unknown header was returned " + responseHeader.getName(), - standardHeaders.remove(responseHeader.getName())); + assertTrue("unknown header was returned " + name, standardHeaders.remove(name)); } } - assertEquals("some headers that were sent weren't returned: " + expectedHeaders, 0, expectedHeaders.size()); - assertEquals("some expected standard headers weren't returned: " + standardHeaders, 0, standardHeaders.size()); + assertTrue("some headers that were sent weren't returned: " + expectedHeaders, expectedHeaders.isEmpty()); + assertTrue("some expected standard headers weren't returned: " + standardHeaders, standardHeaders.isEmpty()); } } diff --git a/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java b/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java index a6ae30b01e8..92e2b0da971 100644 --- a/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java +++ b/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java @@ -19,8 +19,6 @@ package org.elasticsearch.client; -import com.carrotsearch.randomizedtesting.generators.RandomInts; -import com.carrotsearch.randomizedtesting.generators.RandomStrings; import org.apache.http.Header; import org.apache.http.HttpEntity; import org.apache.http.HttpEntityEnclosingRequest; @@ -41,7 +39,6 @@ import org.apache.http.concurrent.FutureCallback; import org.apache.http.conn.ConnectTimeoutException; import org.apache.http.entity.StringEntity; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; -import org.apache.http.message.BasicHeader; import org.apache.http.message.BasicHttpResponse; import org.apache.http.message.BasicStatusLine; import org.apache.http.nio.protocol.HttpAsyncRequestProducer; @@ -58,7 +55,10 @@ import java.net.URI; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Future; import static org.elasticsearch.client.RestClientTestUtil.getAllErrorStatusCodes; @@ -132,13 +132,8 @@ public class RestClientSingleHostTests extends RestClientTestCase { }); - int numHeaders = RandomInts.randomIntBetween(getRandom(), 0, 3); - defaultHeaders = new Header[numHeaders]; - for (int i = 0; i < numHeaders; i++) { - String headerName = "Header-default" + (getRandom().nextBoolean() ? i : ""); - String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10); - defaultHeaders[i] = new BasicHeader(headerName, headerValue); - } + int numHeaders = randomIntBetween(0, 3); + defaultHeaders = generateHeaders("Header-default", "Header-array", numHeaders); httpHost = new HttpHost("localhost", 9200); failureListener = new HostsTrackingFailureListener(); restClient = new RestClient(httpClient, 10000, defaultHeaders, new HttpHost[]{httpHost}, null, failureListener); @@ -333,20 +328,13 @@ public class RestClientSingleHostTests extends RestClientTestCase { */ public void testHeaders() throws IOException { for (String method : getHttpMethods()) { - Map expectedHeaders = new HashMap<>(); - for (Header defaultHeader : defaultHeaders) { - expectedHeaders.put(defaultHeader.getName(), defaultHeader.getValue()); - } - int numHeaders = RandomInts.randomIntBetween(getRandom(), 1, 5); - Header[] headers = new Header[numHeaders]; - for (int i = 0; i < numHeaders; i++) { - String headerName = "Header" + (getRandom().nextBoolean() ? i : ""); - String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10); - headers[i] = new BasicHeader(headerName, headerValue); - expectedHeaders.put(headerName, headerValue); - } + final int numHeaders = randomIntBetween(1, 5); + final Header[] headers = generateHeaders("Header", null, numHeaders); + final Map> expectedHeaders = new HashMap<>(); - int statusCode = randomStatusCode(getRandom()); + addHeaders(expectedHeaders, defaultHeaders, headers); + + final int statusCode = randomStatusCode(getRandom()); Response esResponse; try { esResponse = restClient.performRequest(method, "/" + statusCode, headers); @@ -355,10 +343,18 @@ public class RestClientSingleHostTests extends RestClientTestCase { } assertThat(esResponse.getStatusLine().getStatusCode(), equalTo(statusCode)); for (Header responseHeader : esResponse.getHeaders()) { - String headerValue = expectedHeaders.remove(responseHeader.getName()); - assertNotNull("found response header [" + responseHeader.getName() + "] that wasn't originally sent", headerValue); + final String name = responseHeader.getName(); + final String value = responseHeader.getValue(); + final List values = expectedHeaders.get(name); + assertNotNull("found response header [" + name + "] that wasn't originally sent: " + value, values); + assertTrue("found incorrect response header [" + name + "]: " + value, values.remove(value)); + + // we've collected them all + if (values.isEmpty()) { + expectedHeaders.remove(name); + } } - assertEquals("some headers that were sent weren't returned " + expectedHeaders, 0, expectedHeaders.size()); + assertTrue("some headers that were sent weren't returned " + expectedHeaders, expectedHeaders.isEmpty()); } } @@ -368,11 +364,11 @@ public class RestClientSingleHostTests extends RestClientTestCase { Map params = Collections.emptyMap(); boolean hasParams = randomBoolean(); if (hasParams) { - int numParams = RandomInts.randomIntBetween(getRandom(), 1, 3); + int numParams = randomIntBetween(1, 3); params = new HashMap<>(numParams); for (int i = 0; i < numParams; i++) { String paramKey = "param-" + i; - String paramValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10); + String paramValue = randomAsciiOfLengthBetween(3, 10); params.put(paramKey, paramValue); uriBuilder.addParameter(paramKey, paramValue); } @@ -412,24 +408,24 @@ public class RestClientSingleHostTests extends RestClientTestCase { HttpEntity entity = null; boolean hasBody = request instanceof HttpEntityEnclosingRequest && getRandom().nextBoolean(); if (hasBody) { - entity = new StringEntity(RandomStrings.randomAsciiOfLengthBetween(getRandom(), 10, 100)); + entity = new StringEntity(randomAsciiOfLengthBetween(10, 100)); ((HttpEntityEnclosingRequest) request).setEntity(entity); } Header[] headers = new Header[0]; - for (Header defaultHeader : defaultHeaders) { - //default headers are expected but not sent for each request - request.setHeader(defaultHeader); + final int numHeaders = randomIntBetween(1, 5); + final Set uniqueNames = new HashSet<>(numHeaders); + if (randomBoolean()) { + headers = generateHeaders("Header", "Header-array", numHeaders); + for (Header header : headers) { + request.addHeader(header); + uniqueNames.add(header.getName()); + } } - if (getRandom().nextBoolean()) { - int numHeaders = RandomInts.randomIntBetween(getRandom(), 1, 5); - headers = new Header[numHeaders]; - for (int i = 0; i < numHeaders; i++) { - String headerName = "Header" + (getRandom().nextBoolean() ? i : ""); - String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10); - BasicHeader basicHeader = new BasicHeader(headerName, headerValue); - headers[i] = basicHeader; - request.setHeader(basicHeader); + for (Header defaultHeader : defaultHeaders) { + // request level headers override default headers + if (uniqueNames.contains(defaultHeader.getName()) == false) { + request.addHeader(defaultHeader); } } @@ -459,4 +455,5 @@ public class RestClientSingleHostTests extends RestClientTestCase { throw new UnsupportedOperationException(); } } + } diff --git a/client/test/build.gradle b/client/test/build.gradle index 05d044504ec..a7ffe79ac5c 100644 --- a/client/test/build.gradle +++ b/client/test/build.gradle @@ -30,6 +30,7 @@ install.enabled = false uploadArchives.enabled = false dependencies { + compile "org.apache.httpcomponents:httpcore:${versions.httpcore}" compile "com.carrotsearch.randomizedtesting:randomizedtesting-runner:${versions.randomizedrunner}" compile "junit:junit:${versions.junit}" compile "org.hamcrest:hamcrest-all:${versions.hamcrest}" diff --git a/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java b/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java index 8c506beb5ac..4296932a002 100644 --- a/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java +++ b/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java @@ -31,6 +31,15 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakZombies; import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import org.apache.http.Header; +import org.apache.http.message.BasicHeader; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + @TestMethodProviders({ JUnit3MethodProvider.class }) @@ -43,4 +52,71 @@ import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; @TimeoutSuite(millis = 2 * 60 * 60 * 1000) public abstract class RestClientTestCase extends RandomizedTest { + /** + * Create the specified number of {@link Header}s. + *

+ * Generated header names will be the {@code baseName} plus its index or, rarely, the {@code arrayName} if it's supplied. + * + * @param baseName The base name to use for all headers. + * @param arrayName The optional ({@code null}able) array name to use randomly. + * @param headers The number of headers to create. + * @return Never {@code null}. + */ + protected static Header[] generateHeaders(final String baseName, final String arrayName, final int headers) { + final Header[] generated = new Header[headers]; + for (int i = 0; i < headers; i++) { + String headerName = baseName + i; + if (arrayName != null && rarely()) { + headerName = arrayName; + } + + generated[i] = new BasicHeader(headerName, randomAsciiOfLengthBetween(3, 10)); + } + return generated; + } + + /** + * Create a new {@link List} within the {@code map} if none exists for {@code name} or append to the existing list. + * + * @param map The map to manipulate. + * @param name The name to create/append the list for. + * @param value The value to add. + */ + private static void createOrAppendList(final Map> map, final String name, final String value) { + List values = map.get(name); + + if (values == null) { + values = new ArrayList<>(); + map.put(name, values); + } + + values.add(value); + } + + /** + * Add the {@code headers} to the {@code map} so that related tests can more easily assert that they exist. + *

+ * If both the {@code defaultHeaders} and {@code headers} contain the same {@link Header}, based on its + * {@linkplain Header#getName() name}, then this will only use the {@code Header}(s) from {@code headers}. + * + * @param map The map to build with name/value(s) pairs. + * @param defaultHeaders The headers to add to the map representing default headers. + * @param headers The headers to add to the map representing request-level headers. + * @see #createOrAppendList(Map, String, String) + */ + protected static void addHeaders(final Map> map, final Header[] defaultHeaders, final Header[] headers) { + final Set uniqueHeaders = new HashSet<>(); + for (final Header header : headers) { + final String name = header.getName(); + createOrAppendList(map, name, header.getValue()); + uniqueHeaders.add(name); + } + for (final Header defaultHeader : defaultHeaders) { + final String name = defaultHeader.getName(); + if (uniqueHeaders.contains(name) == false) { + createOrAppendList(map, name, defaultHeader.getValue()); + } + } + } + }