diff --git a/src/main/java/org/elasticsearch/action/support/AbstractListenableActionFuture.java b/src/main/java/org/elasticsearch/action/support/AbstractListenableActionFuture.java index 755ed4f67eb..7d4497d4bd6 100644 --- a/src/main/java/org/elasticsearch/action/support/AbstractListenableActionFuture.java +++ b/src/main/java/org/elasticsearch/action/support/AbstractListenableActionFuture.java @@ -99,7 +99,9 @@ public abstract class AbstractListenableActionFuture extends AdapterAction private void executeListener(final ActionListener listener) { try { - listener.onResponse(actionGet()); + // we use a timeout of 0 to by pass assertion forbidding to call actionGet() (blocking) on a network thread. + // here we know we will never block + listener.onResponse(actionGet(0)); } catch (Throwable e) { listener.onFailure(e); } diff --git a/src/main/java/org/elasticsearch/common/util/concurrent/BaseFuture.java b/src/main/java/org/elasticsearch/common/util/concurrent/BaseFuture.java index 71e07fbadfd..2ef8e1901f9 100644 --- a/src/main/java/org/elasticsearch/common/util/concurrent/BaseFuture.java +++ b/src/main/java/org/elasticsearch/common/util/concurrent/BaseFuture.java @@ -20,7 +20,6 @@ package org.elasticsearch.common.util.concurrent; import com.google.common.annotations.Beta; - import org.elasticsearch.common.Nullable; import org.elasticsearch.transport.Transports; @@ -92,7 +91,7 @@ public abstract class BaseFuture implements Future { @Override public V get(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException, ExecutionException { - Transports.assertNotTransportThread("Blocking operation"); + assert timeout <= 0 || Transports.assertNotTransportThread("Blocking operation"); return sync.get(unit.toNanos(timeout)); } @@ -114,7 +113,7 @@ public abstract class BaseFuture implements Future { */ @Override public V get() throws InterruptedException, ExecutionException { - Transports.assertNotTransportThread("Blocking operation"); + assert Transports.assertNotTransportThread("Blocking operation"); return sync.get(); } diff --git a/src/main/java/org/elasticsearch/transport/Transports.java b/src/main/java/org/elasticsearch/transport/Transports.java index 25f8c52723f..68d828fc72f 100644 --- a/src/main/java/org/elasticsearch/transport/Transports.java +++ b/src/main/java/org/elasticsearch/transport/Transports.java @@ -27,6 +27,9 @@ import java.util.Arrays; public enum Transports { ; + /** threads whose name is prefixed by this string will be considered network threads, even though they aren't */ + public final static String TEST_MOCK_TRANSPORT_THREAD_PREFIX = "__mock_network_thread"; + /** * Utility method to detect whether a thread is a network thread. Typically * used in assertions to make sure that we do not call blocking code from @@ -39,7 +42,8 @@ public enum Transports { NettyTransport.HTTP_SERVER_BOSS_THREAD_NAME_PREFIX, NettyTransport.HTTP_SERVER_WORKER_THREAD_NAME_PREFIX, NettyTransport.TRANSPORT_CLIENT_WORKER_THREAD_NAME_PREFIX, - NettyTransport.TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX)) { + NettyTransport.TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX, + TEST_MOCK_TRANSPORT_THREAD_PREFIX)) { if (threadName.contains(s)) { return true; } @@ -47,13 +51,15 @@ public enum Transports { return false; } - public static void assertTransportThread() { + public static boolean assertTransportThread() { final Thread t = Thread.currentThread(); assert isTransportThread(t) : "Expected transport thread but got [" + t + "]"; + return true; } - public static void assertNotTransportThread(String reason) { + public static boolean assertNotTransportThread(String reason) { final Thread t = Thread.currentThread(); - assert isTransportThread(t) ==false : "Expected current thread [" + t + "] to not be a transport thread. Reason: "; + assert isTransportThread(t) == false : "Expected current thread [" + t + "] to not be a transport thread. Reason: [" + reason + "]"; + return true; } } diff --git a/src/test/java/org/elasticsearch/action/support/ListenableActionFutureTests.java b/src/test/java/org/elasticsearch/action/support/ListenableActionFutureTests.java new file mode 100644 index 00000000000..2c97caf6e8c --- /dev/null +++ b/src/test/java/org/elasticsearch/action/support/ListenableActionFutureTests.java @@ -0,0 +1,76 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.action.support; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.test.ElasticsearchTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.Transports; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public class ListenableActionFutureTests extends ElasticsearchTestCase { + + public void testListenerIsCallableFromNetworkThreads() throws Throwable { + ThreadPool threadPool = new ThreadPool("testListenerIsCallableFromNetworkThreads"); + try { + final PlainListenableActionFuture future = new PlainListenableActionFuture<>(threadPool); + final CountDownLatch listenerCalled = new CountDownLatch(1); + final AtomicReference error = new AtomicReference<>(); + final Object response = new Object(); + future.addListener(new ActionListener() { + @Override + public void onResponse(Object o) { + listenerCalled.countDown(); + } + + @Override + public void onFailure(Throwable e) { + error.set(e); + listenerCalled.countDown(); + } + }); + Thread networkThread = new Thread(new AbstractRunnable() { + @Override + public void onFailure(Throwable t) { + error.set(t); + listenerCalled.countDown(); + } + + @Override + protected void doRun() throws Exception { + future.onResponse(response); + } + }, Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX + "_testListenerIsCallableFromNetworkThread"); + networkThread.start(); + networkThread.join(); + listenerCalled.await(); + if (error.get() != null) { + throw error.get(); + } + } finally { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + + +}