Support Preemptive Authentication with RestClient (#21336)

This adds the necessary `AuthCache` needed to support preemptive authorization. By adding every host to the cache, the automatically added `RequestAuthCache` interceptor will add credentials on the first pass rather than waiting to do it after _each_ anonymous request is rejected (thus always sending everything twice when basic auth is required).
This commit is contained in:
Chris Earle 2017-01-24 11:34:05 -05:00 committed by GitHub
parent 47c0e13a3b
commit f0f75b187a
6 changed files with 157 additions and 41 deletions

View File

@ -25,6 +25,7 @@ import org.apache.http.HttpEntity;
import org.apache.http.HttpHost; import org.apache.http.HttpHost;
import org.apache.http.HttpRequest; import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse; import org.apache.http.HttpResponse;
import org.apache.http.client.AuthCache;
import org.apache.http.client.ClientProtocolException; import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpHead; import org.apache.http.client.methods.HttpHead;
@ -34,8 +35,11 @@ import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpPut; import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.methods.HttpTrace; import org.apache.http.client.methods.HttpTrace;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.client.utils.URIBuilder; import org.apache.http.client.utils.URIBuilder;
import org.apache.http.concurrent.FutureCallback; import org.apache.http.concurrent.FutureCallback;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.client.BasicAuthCache;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.nio.client.methods.HttpAsyncMethods; import org.apache.http.nio.client.methods.HttpAsyncMethods;
import org.apache.http.nio.protocol.HttpAsyncRequestProducer; import org.apache.http.nio.protocol.HttpAsyncRequestProducer;
@ -92,7 +96,7 @@ public class RestClient implements Closeable {
private final long maxRetryTimeoutMillis; private final long maxRetryTimeoutMillis;
private final String pathPrefix; private final String pathPrefix;
private final AtomicInteger lastHostIndex = new AtomicInteger(0); private final AtomicInteger lastHostIndex = new AtomicInteger(0);
private volatile Set<HttpHost> hosts; private volatile HostTuple<Set<HttpHost>> hostTuple;
private final ConcurrentMap<HttpHost, DeadHostState> blacklist = new ConcurrentHashMap<>(); private final ConcurrentMap<HttpHost, DeadHostState> blacklist = new ConcurrentHashMap<>();
private final FailureListener failureListener; private final FailureListener failureListener;
@ -122,11 +126,13 @@ public class RestClient implements Closeable {
throw new IllegalArgumentException("hosts must not be null nor empty"); throw new IllegalArgumentException("hosts must not be null nor empty");
} }
Set<HttpHost> httpHosts = new HashSet<>(); Set<HttpHost> httpHosts = new HashSet<>();
AuthCache authCache = new BasicAuthCache();
for (HttpHost host : hosts) { for (HttpHost host : hosts) {
Objects.requireNonNull(host, "host cannot be null"); Objects.requireNonNull(host, "host cannot be null");
httpHosts.add(host); httpHosts.add(host);
authCache.put(host, new BasicScheme());
} }
this.hosts = Collections.unmodifiableSet(httpHosts); this.hostTuple = new HostTuple<>(Collections.unmodifiableSet(httpHosts), authCache);
this.blacklist.clear(); this.blacklist.clear();
} }
@ -315,19 +321,22 @@ public class RestClient implements Closeable {
setHeaders(request, headers); setHeaders(request, headers);
FailureTrackingResponseListener failureTrackingResponseListener = new FailureTrackingResponseListener(responseListener); FailureTrackingResponseListener failureTrackingResponseListener = new FailureTrackingResponseListener(responseListener);
long startTime = System.nanoTime(); long startTime = System.nanoTime();
performRequestAsync(startTime, nextHost().iterator(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, performRequestAsync(startTime, nextHost(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory,
failureTrackingResponseListener); failureTrackingResponseListener);
} }
private void performRequestAsync(final long startTime, final Iterator<HttpHost> hosts, final HttpRequestBase request, private void performRequestAsync(final long startTime, final HostTuple<Iterator<HttpHost>> hostTuple, final HttpRequestBase request,
final Set<Integer> ignoreErrorCodes, final Set<Integer> ignoreErrorCodes,
final HttpAsyncResponseConsumerFactory httpAsyncResponseConsumerFactory, final HttpAsyncResponseConsumerFactory httpAsyncResponseConsumerFactory,
final FailureTrackingResponseListener listener) { final FailureTrackingResponseListener listener) {
final HttpHost host = hosts.next(); final HttpHost host = hostTuple.hosts.next();
//we stream the request body if the entity allows for it //we stream the request body if the entity allows for it
HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request); final HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request);
HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer = httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer(); final HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer =
client.execute(requestProducer, asyncResponseConsumer, new FutureCallback<HttpResponse>() { httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer();
final HttpClientContext context = HttpClientContext.create();
context.setAuthCache(hostTuple.authCache);
client.execute(requestProducer, asyncResponseConsumer, context, new FutureCallback<HttpResponse>() {
@Override @Override
public void completed(HttpResponse httpResponse) { public void completed(HttpResponse httpResponse) {
try { try {
@ -366,7 +375,7 @@ public class RestClient implements Closeable {
} }
private void retryIfPossible(Exception exception) { private void retryIfPossible(Exception exception) {
if (hosts.hasNext()) { if (hostTuple.hosts.hasNext()) {
//in case we are retrying, check whether maxRetryTimeout has been reached //in case we are retrying, check whether maxRetryTimeout has been reached
long timeElapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime); long timeElapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
long timeout = maxRetryTimeoutMillis - timeElapsedMillis; long timeout = maxRetryTimeoutMillis - timeElapsedMillis;
@ -377,7 +386,7 @@ public class RestClient implements Closeable {
} else { } else {
listener.trackFailure(exception); listener.trackFailure(exception);
request.reset(); request.reset();
performRequestAsync(startTime, hosts, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener); performRequestAsync(startTime, hostTuple, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener);
} }
} else { } else {
listener.onDefinitiveFailure(exception); listener.onDefinitiveFailure(exception);
@ -415,17 +424,18 @@ public class RestClient implements Closeable {
* The iterator returned will never be empty. In case there are no healthy hosts available, or dead ones to be be retried, * The iterator returned will never be empty. In case there are no healthy hosts available, or dead ones to be be retried,
* one dead host gets returned so that it can be retried. * one dead host gets returned so that it can be retried.
*/ */
private Iterable<HttpHost> nextHost() { private HostTuple<Iterator<HttpHost>> nextHost() {
final HostTuple<Set<HttpHost>> hostTuple = this.hostTuple;
Collection<HttpHost> nextHosts = Collections.emptySet(); Collection<HttpHost> nextHosts = Collections.emptySet();
do { do {
Set<HttpHost> filteredHosts = new HashSet<>(hosts); Set<HttpHost> filteredHosts = new HashSet<>(hostTuple.hosts);
for (Map.Entry<HttpHost, DeadHostState> entry : blacklist.entrySet()) { for (Map.Entry<HttpHost, DeadHostState> entry : blacklist.entrySet()) {
if (System.nanoTime() - entry.getValue().getDeadUntilNanos() < 0) { if (System.nanoTime() - entry.getValue().getDeadUntilNanos() < 0) {
filteredHosts.remove(entry.getKey()); filteredHosts.remove(entry.getKey());
} }
} }
if (filteredHosts.isEmpty()) { if (filteredHosts.isEmpty()) {
//last resort: if there are no good hosts to use, return a single dead one, the one that's closest to being retried //last resort: if there are no good host to use, return a single dead one, the one that's closest to being retried
List<Map.Entry<HttpHost, DeadHostState>> sortedHosts = new ArrayList<>(blacklist.entrySet()); List<Map.Entry<HttpHost, DeadHostState>> sortedHosts = new ArrayList<>(blacklist.entrySet());
if (sortedHosts.size() > 0) { if (sortedHosts.size() > 0) {
Collections.sort(sortedHosts, new Comparator<Map.Entry<HttpHost, DeadHostState>>() { Collections.sort(sortedHosts, new Comparator<Map.Entry<HttpHost, DeadHostState>>() {
@ -444,7 +454,7 @@ public class RestClient implements Closeable {
nextHosts = rotatedHosts; nextHosts = rotatedHosts;
} }
} while(nextHosts.isEmpty()); } while(nextHosts.isEmpty());
return nextHosts; return new HostTuple<>(nextHosts.iterator(), hostTuple.authCache);
} }
/** /**
@ -686,4 +696,18 @@ public class RestClient implements Closeable {
} }
} }
/**
* {@code HostTuple} enables the {@linkplain HttpHost}s and {@linkplain AuthCache} to be set together in a thread
* safe, volatile way.
*/
private static class HostTuple<T> {
public final T hosts;
public final AuthCache authCache;
public HostTuple(final T hosts, final AuthCache authCache) {
this.hosts = hosts;
this.authCache = authCache;
}
}
} }

View File

@ -26,8 +26,10 @@ import org.apache.http.HttpResponse;
import org.apache.http.ProtocolVersion; import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine; import org.apache.http.StatusLine;
import org.apache.http.client.methods.HttpUriRequest; import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback; import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException; import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.message.BasicHttpResponse; import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicStatusLine; import org.apache.http.message.BasicStatusLine;
@ -73,13 +75,15 @@ public class RestClientMultipleHostsTests extends RestClientTestCase {
public void createRestClient() throws IOException { public void createRestClient() throws IOException {
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class); CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class), when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() { any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
@Override @Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable { public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0]; HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest(); HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
HttpHost httpHost = requestProducer.getTarget(); HttpHost httpHost = requestProducer.getTarget();
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2]; HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
//return the desired status code or exception depending on the path //return the desired status code or exception depending on the path
if (request.getURI().getPath().equals("/soe")) { if (request.getURI().getPath().equals("/soe")) {
futureCallback.failed(new SocketTimeoutException(httpHost.toString())); futureCallback.failed(new SocketTimeoutException(httpHost.toString()));

View File

@ -26,7 +26,11 @@ import com.sun.net.httpserver.HttpServer;
import org.apache.http.Consts; import org.apache.http.Consts;
import org.apache.http.Header; import org.apache.http.Header;
import org.apache.http.HttpHost; import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.entity.StringEntity; import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.util.EntityUtils; import org.apache.http.util.EntityUtils;
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement;
import org.elasticsearch.mocksocket.MockHttpServer; import org.elasticsearch.mocksocket.MockHttpServer;
@ -48,7 +52,10 @@ import java.util.Set;
import static org.elasticsearch.client.RestClientTestUtil.getAllStatusCodes; import static org.elasticsearch.client.RestClientTestUtil.getAllStatusCodes;
import static org.elasticsearch.client.RestClientTestUtil.getHttpMethods; import static org.elasticsearch.client.RestClientTestUtil.getHttpMethods;
import static org.elasticsearch.client.RestClientTestUtil.randomStatusCode; import static org.elasticsearch.client.RestClientTestUtil.randomStatusCode;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
/** /**
@ -66,22 +73,10 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
@BeforeClass @BeforeClass
public static void startHttpServer() throws Exception { public static void startHttpServer() throws Exception {
String pathPrefixWithoutLeadingSlash; pathPrefix = randomBoolean() ? "/testPathPrefix/" + randomAsciiOfLengthBetween(1, 5) : "";
if (randomBoolean()) {
pathPrefixWithoutLeadingSlash = "testPathPrefix/" + randomAsciiOfLengthBetween(1, 5);
pathPrefix = "/" + pathPrefixWithoutLeadingSlash;
} else {
pathPrefix = pathPrefixWithoutLeadingSlash = "";
}
httpServer = createHttpServer(); httpServer = createHttpServer();
defaultHeaders = RestClientTestUtil.randomHeaders(getRandom(), "Header-default"); defaultHeaders = RestClientTestUtil.randomHeaders(getRandom(), "Header-default");
RestClientBuilder restClientBuilder = RestClient.builder( restClient = createRestClient(false, true);
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
if (pathPrefix.length() > 0) {
restClientBuilder.setPathPrefix((randomBoolean() ? "/" : "") + pathPrefixWithoutLeadingSlash);
}
restClient = restClientBuilder.build();
} }
private static HttpServer createHttpServer() throws Exception { private static HttpServer createHttpServer() throws Exception {
@ -129,6 +124,35 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
} }
} }
private static RestClient createRestClient(final boolean useAuth, final boolean usePreemptiveAuth) {
// provide the username/password for every request
final BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials("user", "pass"));
final RestClientBuilder restClientBuilder = RestClient.builder(
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
if (pathPrefix.length() > 0) {
// sometimes cut off the leading slash
restClientBuilder.setPathPrefix(randomBoolean() ? pathPrefix.substring(1) : pathPrefix);
}
if (useAuth) {
restClientBuilder.setHttpClientConfigCallback(new RestClientBuilder.HttpClientConfigCallback() {
@Override
public HttpAsyncClientBuilder customizeHttpClient(final HttpAsyncClientBuilder httpClientBuilder) {
if (usePreemptiveAuth == false) {
// disable preemptive auth by ignoring any authcache
httpClientBuilder.disableAuthCaching();
}
return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
});
}
return restClientBuilder.build();
}
@AfterClass @AfterClass
public static void stopHttpServers() throws IOException { public static void stopHttpServers() throws IOException {
restClient.close(); restClient.close();
@ -159,7 +183,7 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
assertEquals(method, esResponse.getRequestLine().getMethod()); assertEquals(method, esResponse.getRequestLine().getMethod());
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode()); assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri()); assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
assertHeaders(defaultHeaders, requestHeaders, esResponse.getHeaders(), standardHeaders); assertHeaders(defaultHeaders, requestHeaders, esResponse.getHeaders(), standardHeaders);
for (final Header responseHeader : esResponse.getHeaders()) { for (final Header responseHeader : esResponse.getHeaders()) {
String name = responseHeader.getName(); String name = responseHeader.getName();
@ -189,7 +213,41 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
bodyTest("GET"); bodyTest("GET");
} }
private void bodyTest(String method) throws IOException { /**
* Verify that credentials are sent on the first request with preemptive auth enabled (default when provided with credentials).
*/
public void testPreemptiveAuthEnabled() throws IOException {
final String[] methods = { "POST", "PUT", "GET", "DELETE" };
try (final RestClient restClient = createRestClient(true, true)) {
for (final String method : methods) {
final Response response = bodyTest(restClient, method);
assertThat(response.getHeader("Authorization"), startsWith("Basic"));
}
}
}
/**
* Verify that credentials are <em>not</em> sent on the first request with preemptive auth disabled.
*/
public void testPreemptiveAuthDisabled() throws IOException {
final String[] methods = { "POST", "PUT", "GET", "DELETE" };
try (final RestClient restClient = createRestClient(true, false)) {
for (final String method : methods) {
final Response response = bodyTest(restClient, method);
assertThat(response.getHeader("Authorization"), nullValue());
}
}
}
private Response bodyTest(final String method) throws IOException {
return bodyTest(restClient, method);
}
private Response bodyTest(final RestClient restClient, final String method) throws IOException {
String requestBody = "{ \"field\": \"value\" }"; String requestBody = "{ \"field\": \"value\" }";
StringEntity entity = new StringEntity(requestBody); StringEntity entity = new StringEntity(requestBody);
int statusCode = randomStatusCode(getRandom()); int statusCode = randomStatusCode(getRandom());
@ -201,7 +259,9 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
} }
assertEquals(method, esResponse.getRequestLine().getMethod()); assertEquals(method, esResponse.getRequestLine().getMethod());
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode()); assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri()); assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
assertEquals(requestBody, EntityUtils.toString(esResponse.getEntity())); assertEquals(requestBody, EntityUtils.toString(esResponse.getEntity()));
return esResponse;
} }
} }

View File

@ -34,10 +34,12 @@ import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpPut; import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpTrace; import org.apache.http.client.methods.HttpTrace;
import org.apache.http.client.methods.HttpUriRequest; import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.client.utils.URIBuilder; import org.apache.http.client.utils.URIBuilder;
import org.apache.http.concurrent.FutureCallback; import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException; import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.entity.StringEntity; import org.apache.http.entity.StringEntity;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.message.BasicHttpResponse; import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicStatusLine; import org.apache.http.message.BasicStatusLine;
@ -96,11 +98,13 @@ public class RestClientSingleHostTests extends RestClientTestCase {
public void createRestClient() throws IOException { public void createRestClient() throws IOException {
httpClient = mock(CloseableHttpAsyncClient.class); httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class), when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() { any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
@Override @Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable { public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0]; HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2]; HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest(); HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
//return the desired status code or exception depending on the path //return the desired status code or exception depending on the path
if (request.getURI().getPath().equals("/soe")) { if (request.getURI().getPath().equals("/soe")) {
@ -156,7 +160,7 @@ public class RestClientSingleHostTests extends RestClientTestCase {
for (String httpMethod : getHttpMethods()) { for (String httpMethod : getHttpMethods()) {
HttpUriRequest expectedRequest = performRandomRequest(httpMethod); HttpUriRequest expectedRequest = performRandomRequest(httpMethod);
verify(httpClient, times(++times)).<HttpResponse>execute(requestArgumentCaptor.capture(), verify(httpClient, times(++times)).<HttpResponse>execute(requestArgumentCaptor.capture(),
any(HttpAsyncResponseConsumer.class), any(FutureCallback.class)); any(HttpAsyncResponseConsumer.class), any(HttpClientContext.class), any(FutureCallback.class));
HttpUriRequest actualRequest = (HttpUriRequest)requestArgumentCaptor.getValue().generateRequest(); HttpUriRequest actualRequest = (HttpUriRequest)requestArgumentCaptor.getValue().generateRequest();
assertEquals(expectedRequest.getURI(), actualRequest.getURI()); assertEquals(expectedRequest.getURI(), actualRequest.getURI());
assertEquals(expectedRequest.getClass(), actualRequest.getClass()); assertEquals(expectedRequest.getClass(), actualRequest.getClass());

View File

@ -81,6 +81,29 @@ RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200))
.build(); .build();
-------------------------------------------------- --------------------------------------------------
You can disable Preemptive Authentication, which means that every request will be sent without
authorization headers to see if it is accepted and, upon receiving a HTTP 401 response, it will
resend the exact same request with the basic authentication header. If you wish to do this, then
you can do so by disabling it via the `HttpAsyncClientBuilder`:
[source,java]
--------------------------------------------------
final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(AuthScope.ANY,
new UsernamePasswordCredentials("user", "password"));
RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200))
.setHttpClientConfigCallback(new RestClientBuilder.HttpClientConfigCallback() {
@Override
public HttpAsyncClientBuilder customizeHttpClient(HttpAsyncClientBuilder httpClientBuilder) {
// disable preemptive authentication
httpClientBuilder.disableAuthCaching();
return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
})
.build();
--------------------------------------------------
=== Encrypted communication === Encrypted communication
Encrypted communication can also be configured through the Encrypted communication can also be configured through the

View File

@ -26,6 +26,7 @@ import org.apache.http.HttpHost;
import org.apache.http.HttpResponse; import org.apache.http.HttpResponse;
import org.apache.http.ProtocolVersion; import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine; import org.apache.http.StatusLine;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback; import org.apache.http.concurrent.FutureCallback;
import org.apache.http.entity.ContentType; import org.apache.http.entity.ContentType;
import org.apache.http.entity.InputStreamEntity; import org.apache.http.entity.InputStreamEntity;
@ -430,11 +431,11 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
ContentTooLongException tooLong = new ContentTooLongException("too long!"); ContentTooLongException tooLong = new ContentTooLongException("too long!");
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class); CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class), when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).then(new Answer<Future<HttpResponse>>() { any(HttpClientContext.class), any(FutureCallback.class))).then(new Answer<Future<HttpResponse>>() {
@Override @Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable { public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HeapBufferedAsyncResponseConsumer consumer = (HeapBufferedAsyncResponseConsumer) invocationOnMock.getArguments()[1]; HeapBufferedAsyncResponseConsumer consumer = (HeapBufferedAsyncResponseConsumer) invocationOnMock.getArguments()[1];
FutureCallback callback = (FutureCallback) invocationOnMock.getArguments()[2]; FutureCallback callback = (FutureCallback) invocationOnMock.getArguments()[3];
assertEquals(new ByteSizeValue(100, ByteSizeUnit.MB).bytesAsInt(), consumer.getBufferLimit()); assertEquals(new ByteSizeValue(100, ByteSizeUnit.MB).bytesAsInt(), consumer.getBufferLimit());
callback.failed(tooLong); callback.failed(tooLong);
return null; return null;
@ -495,7 +496,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class); CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class), when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() { any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
int responseCount = 0; int responseCount = 0;
@ -504,7 +505,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
// Throw away the current thread context to simulate running async httpclient's thread pool // Throw away the current thread context to simulate running async httpclient's thread pool
threadPool.getThreadContext().stashContext(); threadPool.getThreadContext().stashContext();
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0]; HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2]; FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
HttpEntityEnclosingRequest request = (HttpEntityEnclosingRequest)requestProducer.generateRequest(); HttpEntityEnclosingRequest request = (HttpEntityEnclosingRequest)requestProducer.generateRequest();
URL resource = resources[responseCount]; URL resource = resources[responseCount];
String path = paths[responseCount++]; String path = paths[responseCount++];