Support Preemptive Authentication with RestClient (#21336)

This adds the necessary `AuthCache` needed to support preemptive authorization. By adding every host to the cache, the automatically added `RequestAuthCache` interceptor will add credentials on the first pass rather than waiting to do it after _each_ anonymous request is rejected (thus always sending everything twice when basic auth is required).
This commit is contained in:
Chris Earle 2017-01-24 11:34:05 -05:00 committed by GitHub
parent 47c0e13a3b
commit f0f75b187a
6 changed files with 157 additions and 41 deletions

View File

@ -25,6 +25,7 @@ import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.AuthCache;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpHead;
@ -34,8 +35,11 @@ import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.methods.HttpTrace;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.client.BasicAuthCache;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.nio.client.methods.HttpAsyncMethods;
import org.apache.http.nio.protocol.HttpAsyncRequestProducer;
@ -92,7 +96,7 @@ public class RestClient implements Closeable {
private final long maxRetryTimeoutMillis;
private final String pathPrefix;
private final AtomicInteger lastHostIndex = new AtomicInteger(0);
private volatile Set<HttpHost> hosts;
private volatile HostTuple<Set<HttpHost>> hostTuple;
private final ConcurrentMap<HttpHost, DeadHostState> blacklist = new ConcurrentHashMap<>();
private final FailureListener failureListener;
@ -122,11 +126,13 @@ public class RestClient implements Closeable {
throw new IllegalArgumentException("hosts must not be null nor empty");
}
Set<HttpHost> httpHosts = new HashSet<>();
AuthCache authCache = new BasicAuthCache();
for (HttpHost host : hosts) {
Objects.requireNonNull(host, "host cannot be null");
httpHosts.add(host);
authCache.put(host, new BasicScheme());
}
this.hosts = Collections.unmodifiableSet(httpHosts);
this.hostTuple = new HostTuple<>(Collections.unmodifiableSet(httpHosts), authCache);
this.blacklist.clear();
}
@ -315,19 +321,22 @@ public class RestClient implements Closeable {
setHeaders(request, headers);
FailureTrackingResponseListener failureTrackingResponseListener = new FailureTrackingResponseListener(responseListener);
long startTime = System.nanoTime();
performRequestAsync(startTime, nextHost().iterator(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory,
performRequestAsync(startTime, nextHost(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory,
failureTrackingResponseListener);
}
private void performRequestAsync(final long startTime, final Iterator<HttpHost> hosts, final HttpRequestBase request,
private void performRequestAsync(final long startTime, final HostTuple<Iterator<HttpHost>> hostTuple, final HttpRequestBase request,
final Set<Integer> ignoreErrorCodes,
final HttpAsyncResponseConsumerFactory httpAsyncResponseConsumerFactory,
final FailureTrackingResponseListener listener) {
final HttpHost host = hosts.next();
final HttpHost host = hostTuple.hosts.next();
//we stream the request body if the entity allows for it
HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request);
HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer = httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer();
client.execute(requestProducer, asyncResponseConsumer, new FutureCallback<HttpResponse>() {
final HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request);
final HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer =
httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer();
final HttpClientContext context = HttpClientContext.create();
context.setAuthCache(hostTuple.authCache);
client.execute(requestProducer, asyncResponseConsumer, context, new FutureCallback<HttpResponse>() {
@Override
public void completed(HttpResponse httpResponse) {
try {
@ -366,7 +375,7 @@ public class RestClient implements Closeable {
}
private void retryIfPossible(Exception exception) {
if (hosts.hasNext()) {
if (hostTuple.hosts.hasNext()) {
//in case we are retrying, check whether maxRetryTimeout has been reached
long timeElapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
long timeout = maxRetryTimeoutMillis - timeElapsedMillis;
@ -377,7 +386,7 @@ public class RestClient implements Closeable {
} else {
listener.trackFailure(exception);
request.reset();
performRequestAsync(startTime, hosts, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener);
performRequestAsync(startTime, hostTuple, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener);
}
} else {
listener.onDefinitiveFailure(exception);
@ -415,17 +424,18 @@ public class RestClient implements Closeable {
* The iterator returned will never be empty. In case there are no healthy hosts available, or dead ones to be be retried,
* one dead host gets returned so that it can be retried.
*/
private Iterable<HttpHost> nextHost() {
private HostTuple<Iterator<HttpHost>> nextHost() {
final HostTuple<Set<HttpHost>> hostTuple = this.hostTuple;
Collection<HttpHost> nextHosts = Collections.emptySet();
do {
Set<HttpHost> filteredHosts = new HashSet<>(hosts);
Set<HttpHost> filteredHosts = new HashSet<>(hostTuple.hosts);
for (Map.Entry<HttpHost, DeadHostState> entry : blacklist.entrySet()) {
if (System.nanoTime() - entry.getValue().getDeadUntilNanos() < 0) {
filteredHosts.remove(entry.getKey());
}
}
if (filteredHosts.isEmpty()) {
//last resort: if there are no good hosts to use, return a single dead one, the one that's closest to being retried
//last resort: if there are no good host to use, return a single dead one, the one that's closest to being retried
List<Map.Entry<HttpHost, DeadHostState>> sortedHosts = new ArrayList<>(blacklist.entrySet());
if (sortedHosts.size() > 0) {
Collections.sort(sortedHosts, new Comparator<Map.Entry<HttpHost, DeadHostState>>() {
@ -444,7 +454,7 @@ public class RestClient implements Closeable {
nextHosts = rotatedHosts;
}
} while(nextHosts.isEmpty());
return nextHosts;
return new HostTuple<>(nextHosts.iterator(), hostTuple.authCache);
}
/**
@ -686,4 +696,18 @@ public class RestClient implements Closeable {
}
}
/**
* {@code HostTuple} enables the {@linkplain HttpHost}s and {@linkplain AuthCache} to be set together in a thread
* safe, volatile way.
*/
private static class HostTuple<T> {
public final T hosts;
public final AuthCache authCache;
public HostTuple(final T hosts, final AuthCache authCache) {
this.hosts = hosts;
this.authCache = authCache;
}
}
}

View File

@ -26,8 +26,10 @@ import org.apache.http.HttpResponse;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicStatusLine;
@ -73,13 +75,15 @@ public class RestClientMultipleHostsTests extends RestClientTestCase {
public void createRestClient() throws IOException {
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
@Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
HttpHost httpHost = requestProducer.getTarget();
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2];
HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
//return the desired status code or exception depending on the path
if (request.getURI().getPath().equals("/soe")) {
futureCallback.failed(new SocketTimeoutException(httpHost.toString()));

View File

@ -26,7 +26,11 @@ import com.sun.net.httpserver.HttpServer;
import org.apache.http.Consts;
import org.apache.http.Header;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.util.EntityUtils;
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement;
import org.elasticsearch.mocksocket.MockHttpServer;
@ -48,7 +52,10 @@ import java.util.Set;
import static org.elasticsearch.client.RestClientTestUtil.getAllStatusCodes;
import static org.elasticsearch.client.RestClientTestUtil.getHttpMethods;
import static org.elasticsearch.client.RestClientTestUtil.randomStatusCode;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
/**
@ -66,22 +73,10 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
@BeforeClass
public static void startHttpServer() throws Exception {
String pathPrefixWithoutLeadingSlash;
if (randomBoolean()) {
pathPrefixWithoutLeadingSlash = "testPathPrefix/" + randomAsciiOfLengthBetween(1, 5);
pathPrefix = "/" + pathPrefixWithoutLeadingSlash;
} else {
pathPrefix = pathPrefixWithoutLeadingSlash = "";
}
pathPrefix = randomBoolean() ? "/testPathPrefix/" + randomAsciiOfLengthBetween(1, 5) : "";
httpServer = createHttpServer();
defaultHeaders = RestClientTestUtil.randomHeaders(getRandom(), "Header-default");
RestClientBuilder restClientBuilder = RestClient.builder(
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
if (pathPrefix.length() > 0) {
restClientBuilder.setPathPrefix((randomBoolean() ? "/" : "") + pathPrefixWithoutLeadingSlash);
}
restClient = restClientBuilder.build();
restClient = createRestClient(false, true);
}
private static HttpServer createHttpServer() throws Exception {
@ -129,6 +124,35 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
}
}
private static RestClient createRestClient(final boolean useAuth, final boolean usePreemptiveAuth) {
// provide the username/password for every request
final BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials("user", "pass"));
final RestClientBuilder restClientBuilder = RestClient.builder(
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
if (pathPrefix.length() > 0) {
// sometimes cut off the leading slash
restClientBuilder.setPathPrefix(randomBoolean() ? pathPrefix.substring(1) : pathPrefix);
}
if (useAuth) {
restClientBuilder.setHttpClientConfigCallback(new RestClientBuilder.HttpClientConfigCallback() {
@Override
public HttpAsyncClientBuilder customizeHttpClient(final HttpAsyncClientBuilder httpClientBuilder) {
if (usePreemptiveAuth == false) {
// disable preemptive auth by ignoring any authcache
httpClientBuilder.disableAuthCaching();
}
return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
});
}
return restClientBuilder.build();
}
@AfterClass
public static void stopHttpServers() throws IOException {
restClient.close();
@ -159,7 +183,7 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
assertEquals(method, esResponse.getRequestLine().getMethod());
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri());
assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
assertHeaders(defaultHeaders, requestHeaders, esResponse.getHeaders(), standardHeaders);
for (final Header responseHeader : esResponse.getHeaders()) {
String name = responseHeader.getName();
@ -189,7 +213,41 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
bodyTest("GET");
}
private void bodyTest(String method) throws IOException {
/**
* Verify that credentials are sent on the first request with preemptive auth enabled (default when provided with credentials).
*/
public void testPreemptiveAuthEnabled() throws IOException {
final String[] methods = { "POST", "PUT", "GET", "DELETE" };
try (final RestClient restClient = createRestClient(true, true)) {
for (final String method : methods) {
final Response response = bodyTest(restClient, method);
assertThat(response.getHeader("Authorization"), startsWith("Basic"));
}
}
}
/**
* Verify that credentials are <em>not</em> sent on the first request with preemptive auth disabled.
*/
public void testPreemptiveAuthDisabled() throws IOException {
final String[] methods = { "POST", "PUT", "GET", "DELETE" };
try (final RestClient restClient = createRestClient(true, false)) {
for (final String method : methods) {
final Response response = bodyTest(restClient, method);
assertThat(response.getHeader("Authorization"), nullValue());
}
}
}
private Response bodyTest(final String method) throws IOException {
return bodyTest(restClient, method);
}
private Response bodyTest(final RestClient restClient, final String method) throws IOException {
String requestBody = "{ \"field\": \"value\" }";
StringEntity entity = new StringEntity(requestBody);
int statusCode = randomStatusCode(getRandom());
@ -201,7 +259,9 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
}
assertEquals(method, esResponse.getRequestLine().getMethod());
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri());
assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
assertEquals(requestBody, EntityUtils.toString(esResponse.getEntity()));
return esResponse;
}
}

View File

@ -34,10 +34,12 @@ import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpTrace;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicStatusLine;
@ -96,11 +98,13 @@ public class RestClientSingleHostTests extends RestClientTestCase {
public void createRestClient() throws IOException {
httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
@Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2];
HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
//return the desired status code or exception depending on the path
if (request.getURI().getPath().equals("/soe")) {
@ -156,7 +160,7 @@ public class RestClientSingleHostTests extends RestClientTestCase {
for (String httpMethod : getHttpMethods()) {
HttpUriRequest expectedRequest = performRandomRequest(httpMethod);
verify(httpClient, times(++times)).<HttpResponse>execute(requestArgumentCaptor.capture(),
any(HttpAsyncResponseConsumer.class), any(FutureCallback.class));
any(HttpAsyncResponseConsumer.class), any(HttpClientContext.class), any(FutureCallback.class));
HttpUriRequest actualRequest = (HttpUriRequest)requestArgumentCaptor.getValue().generateRequest();
assertEquals(expectedRequest.getURI(), actualRequest.getURI());
assertEquals(expectedRequest.getClass(), actualRequest.getClass());

View File

@ -81,6 +81,29 @@ RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200))
.build();
--------------------------------------------------
You can disable Preemptive Authentication, which means that every request will be sent without
authorization headers to see if it is accepted and, upon receiving a HTTP 401 response, it will
resend the exact same request with the basic authentication header. If you wish to do this, then
you can do so by disabling it via the `HttpAsyncClientBuilder`:
[source,java]
--------------------------------------------------
final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(AuthScope.ANY,
new UsernamePasswordCredentials("user", "password"));
RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200))
.setHttpClientConfigCallback(new RestClientBuilder.HttpClientConfigCallback() {
@Override
public HttpAsyncClientBuilder customizeHttpClient(HttpAsyncClientBuilder httpClientBuilder) {
// disable preemptive authentication
httpClientBuilder.disableAuthCaching();
return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
})
.build();
--------------------------------------------------
=== Encrypted communication
Encrypted communication can also be configured through the

View File

@ -26,6 +26,7 @@ import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.InputStreamEntity;
@ -430,11 +431,11 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
ContentTooLongException tooLong = new ContentTooLongException("too long!");
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).then(new Answer<Future<HttpResponse>>() {
any(HttpClientContext.class), any(FutureCallback.class))).then(new Answer<Future<HttpResponse>>() {
@Override
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
HeapBufferedAsyncResponseConsumer consumer = (HeapBufferedAsyncResponseConsumer) invocationOnMock.getArguments()[1];
FutureCallback callback = (FutureCallback) invocationOnMock.getArguments()[2];
FutureCallback callback = (FutureCallback) invocationOnMock.getArguments()[3];
assertEquals(new ByteSizeValue(100, ByteSizeUnit.MB).bytesAsInt(), consumer.getBufferLimit());
callback.failed(tooLong);
return null;
@ -495,7 +496,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
int responseCount = 0;
@ -504,7 +505,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
// Throw away the current thread context to simulate running async httpclient's thread pool
threadPool.getThreadContext().stashContext();
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2];
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
HttpEntityEnclosingRequest request = (HttpEntityEnclosingRequest)requestProducer.generateRequest();
URL resource = resources[responseCount];
String path = paths[responseCount++];