diff --git a/compute/src/main/java/org/jclouds/compute/util/ConcurrentOpenSocketFinder.java b/compute/src/main/java/org/jclouds/compute/util/ConcurrentOpenSocketFinder.java index 42d7de77fb..377e80d38f 100644 --- a/compute/src/main/java/org/jclouds/compute/util/ConcurrentOpenSocketFinder.java +++ b/compute/src/main/java/org/jclouds/compute/util/ConcurrentOpenSocketFinder.java @@ -18,28 +18,27 @@ */ package org.jclouds.compute.util; +import static java.lang.String.format; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Predicates.or; import static com.google.common.base.Throwables.propagate; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.size; import static com.google.common.util.concurrent.Atomics.newReference; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; -import static com.google.common.util.concurrent.MoreExecutors.sameThreadExecutor; +import static org.jclouds.Constants.PROPERTY_USER_THREADS; import static org.jclouds.compute.config.ComputeServiceProperties.TIMEOUT_NODE_RUNNING; -import java.util.Collection; import java.util.NoSuchElementException; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Resource; import javax.inject.Named; -import org.jclouds.Constants; import org.jclouds.compute.domain.NodeMetadata; import org.jclouds.compute.reference.ComputeServiceConstants; import org.jclouds.logging.Logger; @@ -50,13 +49,16 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.base.Predicate; import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; import com.google.common.net.HostAndPort; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.inject.Inject; -public final class ConcurrentOpenSocketFinder implements OpenSocketFinder { +public class ConcurrentOpenSocketFinder implements OpenSocketFinder { @Resource @Named(ComputeServiceConstants.COMPUTE_LOGGER) @@ -68,123 +70,134 @@ public final class ConcurrentOpenSocketFinder implements OpenSocketFinder { @Inject @VisibleForTesting - ConcurrentOpenSocketFinder(SocketOpen socketTester, - @Named(TIMEOUT_NODE_RUNNING) Predicate> nodeRunning, - @Named(Constants.PROPERTY_USER_THREADS) ExecutorService userThreads) { - this.socketTester =checkNotNull(socketTester, "socketTester"); + ConcurrentOpenSocketFinder(SocketOpen socketTester, + @Named(TIMEOUT_NODE_RUNNING) Predicate> nodeRunning, + @Named(PROPERTY_USER_THREADS) ExecutorService userThreads) { + this.socketTester = checkNotNull(socketTester, "socketTester"); this.nodeRunning = checkNotNull(nodeRunning, "nodeRunning"); this.executor = listeningDecorator(checkNotNull(userThreads, "userThreads")); } - public HostAndPort findOpenSocketOnNode(NodeMetadata node, final int port, - long timeoutValue, TimeUnit timeUnits) { - FluentIterable hosts = checkNodeHasIps(node); - ImmutableSet sockets = hosts.transform(new Function() { + @Override + public HostAndPort findOpenSocketOnNode(NodeMetadata node, final int port, long timeoutValue, TimeUnit timeUnits) { + ImmutableSet sockets = checkNodeHasIps(node).transform(new Function() { @Override public HostAndPort apply(String from) { return HostAndPort.fromParts(from, port); } }).toImmutableSet(); - + // Specify a retry period of 1s, expressed in the same time units. long period = timeUnits.convert(1, TimeUnit.SECONDS); - // For storing the result; needed because predicate will just tell us true/false - final AtomicReference result = newReference(); - final AtomicReference nodeReference = newReference(node); + // For retrieving the socket found (if any) + AtomicReference result = newReference(); - Predicate> concurrentOpenSocketFinder = new Predicate>() { - - @Override - public boolean apply(Collection input) { - HostAndPort reachableSocket = findOpenSocket(input); - if (reachableSocket != null) { - result.set(reachableSocket); - return true; - } else { - if (!nodeRunning.apply(nodeReference)) { - throw new IllegalStateException(String.format("Node %s is no longer running; aborting waiting for ip:port connection", nodeReference.get().getId())); - } - return false; - } - } - - }; - - RetryablePredicate> retryingOpenSocketFinder = new RetryablePredicate>( - concurrentOpenSocketFinder, timeoutValue, period, timeUnits); + Predicate> findOrBreak = or(updateRefOnSocketOpen(result), throwISEIfNoLongerRunning(node)); logger.debug(">> blocking on sockets %s for %d %s", sockets, timeoutValue, timeUnits); + boolean passed = retryPredicate(findOrBreak, period, timeoutValue, timeUnits).apply(sockets); - boolean passed = retryingOpenSocketFinder.apply(sockets); - if (passed) { logger.debug("<< socket %s opened", result); assert result.get() != null; return result.get(); } else { logger.warn("<< sockets %s didn't open after %d %s", sockets, timeoutValue, timeUnits); - throw new NoSuchElementException(String.format("could not connect to any ip address port %d on node %s", - port, node)); + throw new NoSuchElementException(format("could not connect to any ip address port %d on node %s", port, node)); } } /** - * Checks if any any of the given HostAndPorts are reachable. It checks them all concurrently, - * and returns the first one found or null if none are reachable. - * - * @return A reachable HostAndPort, or null. - * @throws InterruptedException + * Checks if any any of the given HostAndPorts are reachable. It checks them + * all concurrently, and sets reference to a {@link HostAndPort} if found or + * returns false; */ - private HostAndPort findOpenSocket(final Collection sockets) { - final AtomicReference result = newReference(); - final CountDownLatch latch = new CountDownLatch(1); - final AtomicInteger completeCount = new AtomicInteger(); - - for (final HostAndPort socket : sockets) { - final ListenableFuture future = executor.submit(new Runnable() { + private Predicate> updateRefOnSocketOpen(final AtomicReference reachableSocket) { + return new Predicate>() { - @Override - public void run() { - try { - if (socketTester.apply(socket)) { - result.compareAndSet(null, socket); - latch.countDown(); + @Override + public boolean apply(Iterable input) { + + Builder> futures = ImmutableList.builder(); + for (final HostAndPort socket : input) { + futures.add(executor.submit(new Runnable() { + + @Override + public void run() { + try { + if (socketTester.apply(socket)) { + // only set if the this socket was found first + reachableSocket.compareAndSet(null, socket); + } + } catch (RuntimeException e) { + logger.warn(e, "Error checking reachability of ip:port %s", socket); + } } - } catch (RuntimeException e) { - logger.warn(e, "Error checking reachability of ip:port %s", socket); - } - } - - }); - - future.addListener(new Runnable() { - @Override - public void run() { - if (completeCount.incrementAndGet() >= sockets.size()) { - latch.countDown(); // Tried all; mark as done - } + })); } - - }, sameThreadExecutor()); - } - - try { - latch.await(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw propagate(e); - } - return result.get(); + blockOn(futures.build()); + return reachableSocket.get() != null; + } + + @Override + public String toString() { + return "setAndReturnTrueIfSocketFound()"; + } + }; } - private FluentIterable checkNodeHasIps(NodeMetadata node) { + /** + * Add this via + * {@code Predicates.or(condition, throwISEIfNoLongerRunning(node))} to + * short-circuit {@link RetryablePredicate} looping when the node is no + * longer running. + */ + private Predicate throwISEIfNoLongerRunning(final NodeMetadata node) { + return new Predicate() { + + @Override + public boolean apply(T input) { + if (!nodeRunning.apply(newReference(node))) { + throw new IllegalStateException(node.getId() + " is no longer running; aborting socket open loop"); + } + return false; + } + + @Override + public String toString() { + return "throwISEIfNoLongerRunning(" + node.getId() + ")"; + } + }; + } + + /** + * @param findOrBreak + * throws {@link IllegalStateException} in order to break the retry + * loop + */ + @VisibleForTesting + Predicate retryPredicate(Predicate findOrBreak, long period, long timeoutValue, TimeUnit timeUnits) { + return new RetryablePredicate(findOrBreak, timeoutValue, period, timeUnits); + } + + private static FluentIterable checkNodeHasIps(NodeMetadata node) { FluentIterable ips = FluentIterable.from(concat(node.getPublicAddresses(), node.getPrivateAddresses())); checkState(size(ips) > 0, "node does not have IP addresses configured: " + node); return ips; } + private static void blockOn(Iterable> immutableList) { + try { + Futures.allAsList(immutableList).get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw propagate(e); + } catch (ExecutionException e) { + throw propagate(e); + } + } } diff --git a/compute/src/test/java/org/jclouds/compute/util/ConcurrentOpenSocketFinderTest.java b/compute/src/test/java/org/jclouds/compute/util/ConcurrentOpenSocketFinderTest.java index e04c1140be..dfdc4f6096 100644 --- a/compute/src/test/java/org/jclouds/compute/util/ConcurrentOpenSocketFinderTest.java +++ b/compute/src/test/java/org/jclouds/compute/util/ConcurrentOpenSocketFinderTest.java @@ -23,10 +23,6 @@ import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.base.Throwables.propagate; import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.easymock.EasyMock.createMock; -import static org.easymock.EasyMock.expect; -import static org.easymock.EasyMock.replay; -import static org.easymock.EasyMock.verify; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -36,14 +32,14 @@ import java.util.NoSuchElementException; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import org.easymock.EasyMock; import org.jclouds.compute.domain.NodeMetadata; import org.jclouds.compute.domain.NodeMetadataBuilder; import org.jclouds.predicates.SocketOpen; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import com.google.common.base.Predicate; @@ -61,19 +57,24 @@ public class ConcurrentOpenSocketFinderTest { private final NodeMetadata node = new NodeMetadataBuilder().id("myid").status(NodeMetadata.Status.RUNNING) .publicAddresses(ImmutableSet.of("1.2.3.4")).privateAddresses(ImmutableSet.of("1.2.3.5")).build(); - private final Predicate> alwaysTrue = alwaysTrue(); - private final Predicate> alwaysFalse = alwaysFalse(); + private final SocketOpen socketAlwaysClosed = new SocketOpen() { + @Override + public boolean apply(HostAndPort input) { + return false; + } + }; + + private final Predicate> nodeRunning = alwaysTrue(); + private final Predicate> nodeNotRunning = alwaysFalse(); - private SocketOpen socketTester; private ExecutorService threadPool; - @BeforeMethod + @BeforeClass public void setUp() { - socketTester = createMock(SocketOpen.class); threadPool = Executors.newCachedThreadPool(); } - @AfterMethod(alwaysRun = true) + @AfterClass(alwaysRun = true) public void tearDown() { if (threadPool != null) threadPool.shutdownNow(); @@ -83,11 +84,7 @@ public class ConcurrentOpenSocketFinderTest { public void testRespectsTimeout() throws Exception { final long timeoutMs = 1000; - expect(socketTester.apply(HostAndPort.fromParts("1.2.3.4", 22))).andReturn(false).times(2, Integer.MAX_VALUE); - expect(socketTester.apply(HostAndPort.fromParts("1.2.3.5", 22))).andReturn(false).times(2, Integer.MAX_VALUE); - replay(socketTester); - - OpenSocketFinder finder = new ConcurrentOpenSocketFinder(socketTester, alwaysTrue, threadPool); + OpenSocketFinder finder = new ConcurrentOpenSocketFinder(socketAlwaysClosed, nodeRunning, threadPool); Stopwatch stopwatch = new Stopwatch(); stopwatch.start(); @@ -101,32 +98,31 @@ public class ConcurrentOpenSocketFinderTest { assertTrue(timetaken >= timeoutMs - EARLY_GRACE && timetaken <= timeoutMs + SLOW_GRACE, "timetaken=" + timetaken); - verify(socketTester); } @Test public void testReturnsReachable() throws Exception { - expect(socketTester.apply(HostAndPort.fromParts("1.2.3.4", 22))).andReturn(false).once(); - expect(socketTester.apply(HostAndPort.fromParts("1.2.3.5", 22))).andReturn(true).once(); - replay(socketTester); + SocketOpen secondSocketOpen = new SocketOpen() { + @Override + public boolean apply(HostAndPort input) { + return HostAndPort.fromParts("1.2.3.5", 22).equals(input); + } + }; - OpenSocketFinder finder = new ConcurrentOpenSocketFinder(socketTester, alwaysTrue, threadPool); + OpenSocketFinder finder = new ConcurrentOpenSocketFinder(secondSocketOpen, nodeRunning, threadPool); HostAndPort result = finder.findOpenSocketOnNode(node, 22, 2000, MILLISECONDS); assertEquals(result, HostAndPort.fromParts("1.2.3.5", 22)); - verify(socketTester); } @Test public void testChecksSocketsConcurrently() throws Exception { - // Can't use mock+answer for concurrency tests; EasyMock uses lock in - // ReplayState ControllableSocketOpen socketTester = new ControllableSocketOpen(ImmutableMap.of( HostAndPort.fromParts("1.2.3.4", 22), new SlowCallable(true, 1500), HostAndPort.fromParts("1.2.3.5", 22), new SlowCallable(true, 1000))); - OpenSocketFinder finder = new ConcurrentOpenSocketFinder(socketTester, alwaysTrue, threadPool); + OpenSocketFinder finder = new ConcurrentOpenSocketFinder(socketTester, nodeRunning, threadPool); HostAndPort result = finder.findOpenSocketOnNode(node, 22, 2000, MILLISECONDS); assertEquals(result, HostAndPort.fromParts("1.2.3.5", 22)); @@ -134,13 +130,25 @@ public class ConcurrentOpenSocketFinderTest { @Test public void testAbortsWhenNodeNotRunning() throws Exception { - expect(socketTester.apply(EasyMock. anyObject())).andReturn(false); - replay(socketTester); - OpenSocketFinder finder = new ConcurrentOpenSocketFinder(socketTester, alwaysFalse, threadPool); + OpenSocketFinder finder = new ConcurrentOpenSocketFinder(socketAlwaysClosed, nodeNotRunning, threadPool) { + @Override + protected Predicate retryPredicate(final Predicate findOrBreak, long period, long timeoutValue, + TimeUnit timeUnits) { + return new Predicate() { + @Override + public boolean apply(T input) { + try { + findOrBreak.apply(input); + fail("should have thrown IllegalStateException"); + } catch (IllegalStateException e) { + } + return false; + } + }; + } + }; - Stopwatch stopwatch = new Stopwatch(); - stopwatch.start(); try { finder.findOpenSocketOnNode(node, 22, 2000, MILLISECONDS); fail(); @@ -149,11 +157,6 @@ public class ConcurrentOpenSocketFinderTest { // Note: don't get the "no longer running" message, because // logged+swallowed by RetryablePredicate } - long timetaken = stopwatch.elapsedMillis(); - - assertTrue(timetaken <= SLOW_GRACE, "timetaken=" + timetaken); - - verify(socketTester); } private static class SlowCallable implements Callable {