HTTPCLIENT-2350 - Refactored the connect method in DefaultHttpClientConnectionOperator to enhance flexibility in address resolution, specifically allowing for direct handling of unresolved addresses. Updated DnsResolver to introduce a new resolve method supporting both standard and bypassed DNS lookups, enabling improved support for non-public resolvable hosts like .onion endpoints via SOCKS proxy. Adjusted related tests to align with the new resolution mechanism. (#598)

This commit is contained in:
Arturo Bernal 2024-11-16 16:23:57 +01:00 committed by GitHub
parent 4b2a365c36
commit 0b56a628c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 349 additions and 71 deletions

View File

@ -27,7 +27,12 @@
package org.apache.hc.client5.http;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.hc.core5.annotation.Contract;
import org.apache.hc.core5.annotation.ThreadingBehavior;
@ -61,4 +66,20 @@ public interface DnsResolver {
*/
String resolveCanonicalHostname(String host) throws UnknownHostException;
/**
* Returns a list of {@link InetSocketAddress} for the given host with the given port.
*
* @see InetSocketAddress
*
* @since 5.5
*/
default List<InetSocketAddress> resolve(String host, int port) throws UnknownHostException {
final InetAddress[] inetAddresses = resolve(host);
if (inetAddresses == null) {
return Collections.singletonList(InetSocketAddress.createUnresolved(host, port));
}
return Arrays.stream(inetAddresses)
.map(e -> new InetSocketAddress(e, port))
.collect(Collectors.toList());
}
}

View File

@ -27,13 +27,12 @@
package org.apache.hc.client5.http.impl.io;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.net.ssl.SSLSocket;
@ -154,43 +153,37 @@ public class DefaultHttpClientConnectionOperator implements HttpClientConnection
final SocketConfig socketConfig,
final Object attachment,
final HttpContext context) throws IOException {
Args.notNull(conn, "Connection");
Args.notNull(endpointHost, "Host");
Args.notNull(socketConfig, "Socket config");
Args.notNull(context, "Context");
final InetAddress[] remoteAddresses;
if (endpointHost.getAddress() != null) {
remoteAddresses = new InetAddress[] { endpointHost.getAddress() };
} else {
if (LOG.isDebugEnabled()) {
LOG.debug("{} resolving remote address", endpointHost.getHostName());
}
remoteAddresses = this.dnsResolver.resolve(endpointHost.getHostName());
if (LOG.isDebugEnabled()) {
LOG.debug("{} resolved to {}", endpointHost.getHostName(), remoteAddresses == null ? "null" : Arrays.asList(remoteAddresses));
}
if (remoteAddresses == null || remoteAddresses.length == 0) {
throw new UnknownHostException(endpointHost.getHostName());
}
}
final Timeout soTimeout = socketConfig.getSoTimeout();
final SocketAddress socksProxyAddress = socketConfig.getSocksProxyAddress();
final Proxy socksProxy = socksProxyAddress != null ? new Proxy(Proxy.Type.SOCKS, socksProxyAddress) : null;
final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost);
for (int i = 0; i < remoteAddresses.length; i++) {
final InetAddress address = remoteAddresses[i];
final boolean last = i == remoteAddresses.length - 1;
final InetSocketAddress remoteAddress = new InetSocketAddress(address, port);
final List<InetSocketAddress> remoteAddresses;
if (endpointHost.getAddress() != null) {
remoteAddresses = Collections.singletonList(
new InetSocketAddress(endpointHost.getAddress(), this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost)));
} else {
final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost);
remoteAddresses = this.dnsResolver.resolve(endpointHost.getHostName(), port);
}
for (int i = 0; i < remoteAddresses.size(); i++) {
final InetSocketAddress remoteAddress = remoteAddresses.get(i);
final boolean last = i == remoteAddresses.size() - 1;
onBeforeSocketConnect(context, endpointHost);
if (LOG.isDebugEnabled()) {
LOG.debug("{} connecting {}->{} ({})", endpointHost, localAddress, remoteAddress, connectTimeout);
}
final Socket socket = detachedSocketFactory.create(socksProxy);
try {
// Always bind to the local address if it's provided.
if (localAddress != null) {
socket.bind(localAddress);
}
conn.bind(socket);
if (soTimeout != null) {
socket.setSoTimeout(soTimeout.toMillisecondsIntBound());
@ -209,16 +202,11 @@ public class DefaultHttpClientConnectionOperator implements HttpClientConnection
if (linger >= 0) {
socket.setSoLinger(true, linger);
}
if (localAddress != null) {
socket.bind(localAddress);
}
socket.connect(remoteAddress, TimeValue.isPositive(connectTimeout) ? connectTimeout.toMillisecondsIntBound() : 0);
conn.bind(socket);
onAfterSocketConnect(context, endpointHost);
if (LOG.isDebugEnabled()) {
LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost,
conn.getLocalAddress(), conn.getRemoteAddress());
LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost, conn.getLocalAddress(), conn.getRemoteAddress());
}
conn.setSocketTimeout(soTimeout);
final TlsSocketStrategy tlsSocketStrategy = tlsSocketStrategyLookup != null ? tlsSocketStrategyLookup.lookup(endpointHost.getSchemeName()) : null;
@ -245,7 +233,7 @@ public class DefaultHttpClientConnectionOperator implements HttpClientConnection
if (LOG.isDebugEnabled()) {
LOG.debug("{} connection to {} failed ({}); terminating operation", endpointHost, remoteAddress, ex.getClass());
}
throw ConnectExceptionSupport.enhance(ex, endpointHost, remoteAddresses);
throw ConnectExceptionSupport.enhance(ex, endpointHost);
}
if (LOG.isDebugEnabled()) {
LOG.debug("{} connection to {} failed ({}); retrying connection to the next address", endpointHost, remoteAddress, ex.getClass());

View File

@ -32,7 +32,7 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
@ -108,11 +108,11 @@ final class MultihomeIOSessionRequester {
LOG.debug("{} resolving remote address", remoteEndpoint.getHostName());
}
final InetAddress[] remoteAddresses;
final List<InetSocketAddress> remoteAddresses;
try {
remoteAddresses = dnsResolver.resolve(remoteEndpoint.getHostName());
if (remoteAddresses == null || remoteAddresses.length == 0) {
throw new UnknownHostException(remoteEndpoint.getHostName());
remoteAddresses = dnsResolver.resolve(remoteEndpoint.getHostName(), remoteEndpoint.getPort());
if (remoteAddresses == null || remoteAddresses.isEmpty()) {
throw new UnknownHostException(remoteEndpoint.getHostName());
}
} catch (final UnknownHostException ex) {
future.failed(ex);
@ -120,7 +120,7 @@ final class MultihomeIOSessionRequester {
}
if (LOG.isDebugEnabled()) {
LOG.debug("{} resolved to {}", remoteEndpoint.getHostName(), Arrays.asList(remoteAddresses));
LOG.debug("{} resolved to {}", remoteEndpoint.getHostName(), remoteAddresses);
}
final Runnable runnable = new Runnable() {
@ -129,7 +129,7 @@ final class MultihomeIOSessionRequester {
void executeNext() {
final int index = attempt.getAndIncrement();
final InetSocketAddress remoteAddress = new InetSocketAddress(remoteAddresses[index], remoteEndpoint.getPort());
final InetSocketAddress remoteAddress = remoteAddresses.get(index);
if (LOG.isDebugEnabled()) {
LOG.debug("{}:{} connecting {}->{} ({})",
@ -155,13 +155,17 @@ final class MultihomeIOSessionRequester {
@Override
public void failed(final Exception cause) {
if (attempt.get() >= remoteAddresses.length) {
if (attempt.get() >= remoteAddresses.size()) {
if (LOG.isDebugEnabled()) {
LOG.debug("{}:{} connection to {} failed ({}); terminating operation",
remoteEndpoint.getHostName(), remoteEndpoint.getPort(), remoteAddress, cause.getClass());
}
if (cause instanceof IOException) {
future.failed(ConnectExceptionSupport.enhance((IOException) cause, remoteEndpoint, remoteAddresses));
final InetAddress[] addresses = remoteAddresses.stream()
.filter(addr -> addr instanceof InetSocketAddress)
.map(addr -> ((InetSocketAddress) addr).getAddress())
.toArray(InetAddress[]::new);
future.failed(ConnectExceptionSupport.enhance((IOException) cause, remoteEndpoint, addresses));
} else {
future.failed(cause);
}

View File

@ -30,6 +30,7 @@ package org.apache.hc.client5.http.impl.io;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLSocket;
@ -384,7 +385,7 @@ class TestBasicHttpClientConnectionManager {
.build();
mgr.setTlsConfig(tlsConfig);
Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] {remote});
Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443)));
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
@ -398,7 +399,7 @@ class TestBasicHttpClientConnectionManager {
mgr.connect(endpoint1, null, context);
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234);
@ -406,7 +407,7 @@ class TestBasicHttpClientConnectionManager {
mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context);
Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123);
@ -441,7 +442,7 @@ class TestBasicHttpClientConnectionManager {
.build();
mgr.setTlsConfig(tlsConfig);
Mockito.when(dnsResolver.resolve("someproxy")).thenReturn(new InetAddress[] {remote});
Mockito.when(dnsResolver.resolve("someproxy", 8080)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8080)));
Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080);
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
@ -449,7 +450,7 @@ class TestBasicHttpClientConnectionManager {
mgr.connect(endpoint1, null, context);
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234);

View File

@ -32,6 +32,9 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLSocket;
@ -88,9 +91,14 @@ class TestHttpClientConnectionOperator {
final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0});
final InetAddress ip1 = InetAddress.getByAddress(new byte[] {127, 0, 0, 1});
final InetAddress ip2 = InetAddress.getByAddress(new byte[] {127, 0, 0, 2});
final int port = 80;
final List<InetSocketAddress> resolvedAddresses = Arrays.asList(
new InetSocketAddress(ip1, port),
new InetSocketAddress(ip2, port)
);
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);
Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 });
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(80);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
final SocketConfig socketConfig = SocketConfig.custom()
@ -110,7 +118,7 @@ class TestHttpClientConnectionOperator {
Mockito.verify(socket).setTcpNoDelay(true);
Mockito.verify(socket).bind(localAddress);
Mockito.verify(socket).connect(new InetSocketAddress(ip1, 80), 123);
Mockito.verify(socket).connect(new InetSocketAddress(ip1, port), 123);
Mockito.verify(conn, Mockito.times(2)).bind(socket);
}
@ -121,14 +129,20 @@ class TestHttpClientConnectionOperator {
final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0});
final InetAddress ip1 = InetAddress.getByAddress(new byte[] {127, 0, 0, 1});
final InetAddress ip2 = InetAddress.getByAddress(new byte[] {127, 0, 0, 2});
final int port = 443;
final TlsConfig tlsConfig = TlsConfig.custom()
.setHandshakeTimeout(Timeout.ofMilliseconds(345))
.setVersionPolicy(HttpVersionPolicy.FORCE_HTTP_1)
.build();
Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 });
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(443);
final List<InetSocketAddress> resolvedAddresses = Arrays.asList(
new InetSocketAddress(ip1, port),
new InetSocketAddress(ip2, port)
);
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
@ -144,27 +158,32 @@ class TestHttpClientConnectionOperator {
connectionOperator.connect(conn, host, null, localAddress,
Timeout.ofMilliseconds(123), SocketConfig.DEFAULT, tlsConfig, context);
Mockito.verify(socket).connect(new InetSocketAddress(ip1, 443), 123);
Mockito.verify(socket).connect(new InetSocketAddress(ip1, port), 123);
Mockito.verify(conn, Mockito.times(2)).bind(socket);
Mockito.verify(tlsSocketStrategy).upgrade(socket, "somehost", -1, tlsConfig, context);
Mockito.verify(conn, Mockito.times(1)).bind(upgradedSocket, socket);
}
@Test
void testConnectTimeout() throws Exception {
final HttpClientContext context = HttpClientContext.create();
final HttpHost host = new HttpHost("somehost");
final int port = 80;
final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1});
final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2});
Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 });
Mockito.when(schemePortResolver.resolve(host)).thenReturn(80);
final List<InetSocketAddress> resolvedAddresses = Arrays.asList(
new InetSocketAddress(ip1, port),
new InetSocketAddress(ip2, port)
);
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.doThrow(new SocketTimeoutException()).when(socket).connect(Mockito.any(), Mockito.anyInt());
Assertions.assertThrows(ConnectTimeoutException.class, () ->
connectionOperator.connect(
conn, host, null, TimeValue.ofMilliseconds(1000), SocketConfig.DEFAULT, context));
conn, host, null, new InetSocketAddress(InetAddress.getLoopbackAddress(), 0),
Timeout.ofMilliseconds(1000), SocketConfig.DEFAULT, null, context));
}
@Test
@ -173,9 +192,14 @@ class TestHttpClientConnectionOperator {
final HttpHost host = new HttpHost("somehost");
final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1});
final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2});
final int port = 80;
final List<InetSocketAddress> resolvedAddresses = Arrays.asList(
new InetSocketAddress(ip1, port),
new InetSocketAddress(ip2, port)
);
Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses);
Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 });
Mockito.when(schemePortResolver.resolve(host)).thenReturn(80);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.doThrow(new ConnectException()).when(socket).connect(Mockito.any(), Mockito.anyInt());
@ -189,14 +213,14 @@ class TestHttpClientConnectionOperator {
final HttpClientContext context = HttpClientContext.create();
final HttpHost host = new HttpHost("somehost");
final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0});
final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1});
final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2});
final InetSocketAddress ipAddress1 = new InetSocketAddress(InetAddress.getByAddress(new byte[] {10, 0, 0, 1}), 80);
final InetSocketAddress ipAddress2 = new InetSocketAddress(InetAddress.getByAddress(new byte[] {10, 0, 0, 2}), 80);
Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 });
Mockito.when(dnsResolver.resolve("somehost", 80)).thenReturn(Arrays.asList(ipAddress1, ipAddress2));
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(80);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
Mockito.doThrow(new ConnectException()).when(socket).connect(
Mockito.eq(new InetSocketAddress(ip1, 80)),
Mockito.eq(ipAddress1),
Mockito.anyInt());
final InetSocketAddress localAddress = new InetSocketAddress(local, 0);
@ -206,7 +230,7 @@ class TestHttpClientConnectionOperator {
Timeout.ofMilliseconds(123), SocketConfig.DEFAULT, tlsConfig, context);
Mockito.verify(socket, Mockito.times(2)).bind(localAddress);
Mockito.verify(socket).connect(new InetSocketAddress(ip2, 80), 123);
Mockito.verify(socket).connect(ipAddress2, 123);
Mockito.verify(conn, Mockito.times(3)).bind(socket);
}
@ -229,7 +253,7 @@ class TestHttpClientConnectionOperator {
Mockito.verify(socket).bind(localAddress);
Mockito.verify(socket).connect(new InetSocketAddress(ip, 80), 123);
Mockito.verify(dnsResolver, Mockito.never()).resolve(Mockito.anyString());
Mockito.verify(dnsResolver, Mockito.never()).resolve(Mockito.anyString(), Mockito.anyInt());
Mockito.verify(conn, Mockito.times(2)).bind(socket);
}
@ -279,4 +303,82 @@ class TestHttpClientConnectionOperator {
connectionOperator.upgrade(conn, host, context));
}
@Test
void testConnectWithDisableDnsResolution() throws Exception {
final HttpClientContext context = HttpClientContext.create();
final HttpHost host = new HttpHost("someonion.onion");
final InetAddress local = InetAddress.getByAddress(new byte[]{127, 0, 0, 0});
final int port = 80;
final List<InetSocketAddress> resolvedAddresses = Collections.singletonList(
InetSocketAddress.createUnresolved(host.getHostName(), port)
);
Mockito.when(dnsResolver.resolve(host.getHostName(), port)).thenReturn(resolvedAddresses);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
final SocketConfig socketConfig = SocketConfig.custom()
.setSoKeepAlive(true)
.setSoReuseAddress(true)
.setSoTimeout(5000, TimeUnit.MILLISECONDS)
.setTcpNoDelay(true)
.setSoLinger(50, TimeUnit.MILLISECONDS)
.build();
final InetSocketAddress localAddress = new InetSocketAddress(local, 0);
final InetSocketAddress remoteAddress = InetSocketAddress.createUnresolved(host.getHostName(), port);
connectionOperator.connect(conn, host, null, localAddress, Timeout.ofMilliseconds(123), socketConfig, null, context);
// Verify that the socket was created and attempted to connect without DNS resolution
Mockito.verify(socket).setKeepAlive(true);
Mockito.verify(socket).setReuseAddress(true);
Mockito.verify(socket).setSoTimeout(5000);
Mockito.verify(socket).setSoLinger(true, 50);
Mockito.verify(socket).setTcpNoDelay(true);
Mockito.verify(socket).bind(localAddress);
Mockito.verify(socket).connect(remoteAddress, 123);
Mockito.verify(conn, Mockito.times(2)).bind(socket);
Mockito.verify(dnsResolver, Mockito.never()).resolve(Mockito.anyString());
}
@Test
void testConnectWithDnsResolutionAndFallback() throws Exception {
final HttpClientContext context = HttpClientContext.create();
final HttpHost host = new HttpHost("fallbackhost.com");
final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0});
final int port = 8080;
final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1});
final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2});
// Update to match the new `resolve` implementation that returns a list of SocketAddress
final List<InetSocketAddress> resolvedAddresses = Arrays.asList(
new InetSocketAddress(ip1, port),
new InetSocketAddress(ip2, port)
);
Mockito.when(dnsResolver.resolve("fallbackhost.com", port)).thenReturn(resolvedAddresses);
Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
// Simulate failure to connect to the first resolved address
Mockito.doThrow(new ConnectException()).when(socket).connect(Mockito.eq(new InetSocketAddress(ip1, port)), Mockito.anyInt());
final InetSocketAddress localAddress = new InetSocketAddress(local, 0);
final SocketConfig socketConfig = SocketConfig.custom()
.setSoKeepAlive(true)
.setSoReuseAddress(true)
.setSoTimeout(5000, TimeUnit.MILLISECONDS)
.setTcpNoDelay(true)
.setSoLinger(50, TimeUnit.MILLISECONDS)
.build();
// Connect using the updated connection operator
connectionOperator.connect(conn, host, null, localAddress, Timeout.ofMilliseconds(123), socketConfig, null, context);
// Verify fallback behavior after connection failure to the first address
Mockito.verify(socket, Mockito.times(2)).bind(localAddress);
Mockito.verify(socket).connect(new InetSocketAddress(ip2, port), 123);
Mockito.verify(conn, Mockito.times(3)).bind(socket);
}
}

View File

@ -30,6 +30,7 @@ package org.apache.hc.client5.http.impl.io;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Collections;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@ -264,7 +265,7 @@ class TestPoolingHttpClientConnectionManager {
.build();
mgr.setDefaultTlsConfig(tlsConfig);
Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[]{remote});
Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443)));
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket);
@ -278,7 +279,7 @@ class TestPoolingHttpClientConnectionManager {
mgr.connect(endpoint1, null, context);
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234);
@ -286,7 +287,7 @@ class TestPoolingHttpClientConnectionManager {
mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context);
Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost");
Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443);
Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target);
Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123);
@ -331,7 +332,7 @@ class TestPoolingHttpClientConnectionManager {
.build();
mgr.setDefaultTlsConfig(tlsConfig);
Mockito.when(dnsResolver.resolve("someproxy")).thenReturn(new InetAddress[] {remote});
Mockito.when(dnsResolver.resolve("someproxy", 8080)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8080)));
Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080);
Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443);
Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy);
@ -339,7 +340,7 @@ class TestPoolingHttpClientConnectionManager {
mgr.connect(endpoint1, null, context);
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy");
Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080);
Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy);
Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null);
Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234);

View File

@ -0,0 +1,161 @@
/*
* ====================================================================
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* ====================================================================
*
* This software consists of voluntary contributions made by many
* individuals on behalf of the Apache Software Foundation. For more
* information on the Apache Software Foundation, please see
* <http://www.apache.org/>.
*
*/
package org.apache.hc.client5.http.impl.nio;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.apache.hc.client5.http.DnsResolver;
import org.apache.hc.core5.concurrent.FutureCallback;
import org.apache.hc.core5.net.NamedEndpoint;
import org.apache.hc.core5.reactor.ConnectionInitiator;
import org.apache.hc.core5.reactor.IOSession;
import org.apache.hc.core5.util.Timeout;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
class MultihomeIOSessionRequesterTest {
private DnsResolver dnsResolver;
private ConnectionInitiator connectionInitiator;
private MultihomeIOSessionRequester sessionRequester;
private NamedEndpoint namedEndpoint;
@BeforeEach
void setUp() {
dnsResolver = Mockito.mock(DnsResolver.class);
connectionInitiator = Mockito.mock(ConnectionInitiator.class);
namedEndpoint = Mockito.mock(NamedEndpoint.class);
sessionRequester = new MultihomeIOSessionRequester(dnsResolver);
}
@Test
void testConnectWithMultipleAddresses() throws Exception {
final InetAddress address1 = InetAddress.getByAddress(new byte[]{10, 0, 0, 1});
final InetAddress address2 = InetAddress.getByAddress(new byte[]{10, 0, 0, 2});
final List<InetSocketAddress> remoteAddresses = Arrays.asList(
new InetSocketAddress(address1, 8080),
new InetSocketAddress(address2, 8080)
);
Mockito.when(namedEndpoint.getHostName()).thenReturn("somehost");
Mockito.when(namedEndpoint.getPort()).thenReturn(8080);
Mockito.when(dnsResolver.resolve("somehost", 8080)).thenReturn(remoteAddresses);
Mockito.when(connectionInitiator.connect(any(), any(), any(), any(), any(), any()))
.thenAnswer(invocation -> {
final FutureCallback<IOSession> callback = invocation.getArgument(5);
// Simulate a failure for the first connection attempt
final CompletableFuture<IOSession> future = new CompletableFuture<>();
callback.failed(new IOException("Simulated connection failure"));
future.completeExceptionally(new IOException("Simulated connection failure"));
return future;
});
final Future<IOSession> future = sessionRequester.connect(
connectionInitiator,
namedEndpoint,
null,
Timeout.ofMilliseconds(500),
null,
null
);
assertTrue(future.isDone());
try {
future.get();
fail("Expected ExecutionException");
} catch (final ExecutionException ex) {
assertInstanceOf(IOException.class, ex.getCause());
assertEquals("Simulated connection failure", ex.getCause().getMessage());
}
}
@Test
void testConnectSuccessfulAfterRetries() throws Exception {
final InetAddress address1 = InetAddress.getByAddress(new byte[]{10, 0, 0, 1});
final InetAddress address2 = InetAddress.getByAddress(new byte[]{10, 0, 0, 2});
final List<InetSocketAddress> remoteAddresses = Arrays.asList(
new InetSocketAddress(address1, 8080),
new InetSocketAddress(address2, 8080)
);
Mockito.when(namedEndpoint.getHostName()).thenReturn("somehost");
Mockito.when(namedEndpoint.getPort()).thenReturn(8080);
Mockito.when(dnsResolver.resolve("somehost", 8080)).thenReturn(remoteAddresses);
Mockito.when(connectionInitiator.connect(any(), any(), any(), any(), any(), any()))
.thenAnswer(invocation -> {
final FutureCallback<IOSession> callback = invocation.getArgument(5);
final InetSocketAddress remoteAddress = invocation.getArgument(1);
final CompletableFuture<IOSession> future = new CompletableFuture<>();
if (remoteAddress.getAddress().equals(address1)) {
// Fail the first address
callback.failed(new IOException("Simulated connection failure"));
future.completeExceptionally(new IOException("Simulated connection failure"));
} else {
// Succeed for the second address
final IOSession mockSession = Mockito.mock(IOSession.class);
callback.completed(mockSession);
future.complete(mockSession);
}
return future;
});
final Future<IOSession> future = sessionRequester.connect(
connectionInitiator,
namedEndpoint,
null,
Timeout.ofMilliseconds(500),
null,
null
);
assertTrue(future.isDone());
try {
final IOSession session = future.get();
assertNotNull(session);
} catch (final ExecutionException ex) {
fail("Did not expect an ExecutionException", ex);
}
}
}