diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 390eaaa18eb..eaf9a3bae1b 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -37,13 +37,16 @@ import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; @@ -553,31 +556,40 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { } public void testTimeoutSendExceptionWithDelayedResponse() throws Exception { - CountDownLatch doneLatch = new CountDownLatch(1); - CountDownLatch allResponded = new CountDownLatch(1); + CountDownLatch waitForever = new CountDownLatch(1); + CountDownLatch doneWaitingForever = new CountDownLatch(1); + AtomicInteger inFlight = new AtomicInteger(0); serviceA.registerRequestHandler("sayHelloTimeoutDelayedResponse", StringMessageRequest::new, ThreadPool.Names.GENERIC, new TransportRequestHandler() { @Override - public void messageReceived(StringMessageRequest request, TransportChannel channel) { - TimeValue sleep = TimeValue.parseTimeValue(request.message, null, "sleep"); + public void messageReceived(StringMessageRequest request, TransportChannel channel) throws InterruptedException { + inFlight.incrementAndGet(); try { - doneLatch.await(sleep.millis(), TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - // ignore - } - try { - channel.sendResponse(new StringMessageResponse("hello " + request.message)); - } catch (IOException e) { - logger.error("Unexpected failure", e); - fail(e.getMessage()); + String message = request.message; + if ("forever".equals(message)) { + waitForever.await(); + } else { + TimeValue sleep = TimeValue.parseTimeValue(message, null, "sleep"); + Thread.sleep(sleep.millis()); + } + try { + channel.sendResponse(new StringMessageResponse("hello " + request.message)); + } catch (IOException e) { + logger.error("Unexpected failure", e); + fail(e.getMessage()); + } finally { + if ("forever".equals(message)) { + doneWaitingForever.countDown(); + } + } } finally { - allResponded.countDown(); + inFlight.decrementAndGet(); } } }); final CountDownLatch latch = new CountDownLatch(1); TransportFuture res = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse", - new StringMessageRequest("2m"), TransportRequestOptions.builder().withTimeout(100).build(), + new StringMessageRequest("forever"), TransportRequestOptions.builder().withTimeout(100).build(), new TransportResponseHandler() { @Override public StringMessageResponse newInstance() { @@ -603,17 +615,18 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { }); try { - StringMessageResponse message = res.txGet(); + res.txGet(); fail("exception should be thrown"); } catch (Exception e) { assertThat(e, instanceOf(ReceiveTimeoutTransportException.class)); } latch.await(); + List assertions = new ArrayList<>(); for (int i = 0; i < 10; i++) { final int counter = i; // now, try and send another request, this times, with a short timeout - res = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse", + TransportFuture result = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse", new StringMessageRequest(counter + "ms"), TransportRequestOptions.builder().withTimeout(3000).build(), new TransportResponseHandler() { @Override @@ -638,13 +651,18 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { } }); - StringMessageResponse message = res.txGet(); - assertThat(message.message, equalTo("hello " + counter + "ms")); + assertions.add(() -> { + StringMessageResponse message = result.txGet(); + assertThat(message.message, equalTo("hello " + counter + "ms")); + }); + } + for (Runnable runnable : assertions) { + runnable.run(); } - serviceA.removeHandler("sayHelloTimeoutDelayedResponse"); - doneLatch.countDown(); - allResponded.await(); + waitForever.countDown(); + doneWaitingForever.await(); + assertEquals(0, inFlight.get()); } @TestLogging(value = "test. transport.tracer:TRACE")