diff --git a/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java b/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java index e51a97bdaaa..546875fd2c7 100644 --- a/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java +++ b/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java @@ -49,8 +49,12 @@ public class PrioritizedEsThreadPoolExecutor extends EsThreadPoolExecutor { @Override public void execute(Runnable command) { + if (command instanceof PrioritizedRunnable) { + super.execute(new TieBreakingPrioritizedRunnable((PrioritizedRunnable) command, tieBreaker.incrementAndGet())); + return; + } if (!(command instanceof Comparable)) { - command = PrioritizedRunnable.wrap(command, Priority.NORMAL); + command = new TieBreakingPrioritizedRunnable(command, Priority.NORMAL, tieBreaker.incrementAndGet()); } super.execute(command); } @@ -71,6 +75,36 @@ public class PrioritizedEsThreadPoolExecutor extends EsThreadPoolExecutor { return new PrioritizedFutureTask((PrioritizedCallable) callable, tieBreaker.incrementAndGet()); } + static class TieBreakingPrioritizedRunnable extends PrioritizedRunnable { + + private final Runnable runnable; + private final long tieBreaker; + + TieBreakingPrioritizedRunnable(PrioritizedRunnable runnable, long tieBreaker) { + this(runnable, runnable.priority(), tieBreaker); + } + + TieBreakingPrioritizedRunnable(Runnable runnable, Priority priority, long tieBreaker) { + super(priority); + this.runnable = runnable; + this.tieBreaker = tieBreaker; + } + + @Override + public void run() { + runnable.run(); + } + + @Override + public int compareTo(PrioritizedRunnable pr) { + int res = super.compareTo(pr); + if (res != 0 || !(pr instanceof TieBreakingPrioritizedRunnable)) { + return res; + } + return tieBreaker < ((TieBreakingPrioritizedRunnable)pr).tieBreaker ? -1 : 1; + } + } + /** * */ diff --git a/src/test/java/org/elasticsearch/test/unit/common/util/concurrent/PrioritizedExecutorsTests.java b/src/test/java/org/elasticsearch/test/unit/common/util/concurrent/PrioritizedExecutorsTests.java index 23892baa8ad..c22920a81bb 100644 --- a/src/test/java/org/elasticsearch/test/unit/common/util/concurrent/PrioritizedExecutorsTests.java +++ b/src/test/java/org/elasticsearch/test/unit/common/util/concurrent/PrioritizedExecutorsTests.java @@ -57,7 +57,7 @@ public class PrioritizedExecutorsTests { } @Test - public void testPrioritizedExecutorWithRunnables() throws Exception { + public void testSubmitPrioritizedExecutorWithRunnables() throws Exception { ExecutorService executor = EsExecutors.newSinglePrioritizingThreadExecutor(Executors.defaultThreadFactory()); List results = new ArrayList(7); CountDownLatch awaitingLatch = new CountDownLatch(1); @@ -84,7 +84,34 @@ public class PrioritizedExecutorsTests { } @Test - public void testPrioritizedExecutorWithCallables() throws Exception { + public void testExecutePrioritizedExecutorWithRunnables() throws Exception { + ExecutorService executor = EsExecutors.newSinglePrioritizingThreadExecutor(Executors.defaultThreadFactory()); + List results = new ArrayList(7); + CountDownLatch awaitingLatch = new CountDownLatch(1); + CountDownLatch finishedLatch = new CountDownLatch(7); + executor.execute(new AwaitingJob(awaitingLatch)); + executor.execute(new Job(6, Priority.LANGUID, results, finishedLatch)); + executor.execute(new Job(4, Priority.LOW, results, finishedLatch)); + executor.execute(new Job(1, Priority.HIGH, results, finishedLatch)); + executor.execute(new Job(5, Priority.LOW, results, finishedLatch)); // will execute after the first LOW (fifo) + executor.execute(new Job(0, Priority.URGENT, results, finishedLatch)); + executor.execute(new Job(3, Priority.NORMAL, results, finishedLatch)); + executor.execute(new Job(2, Priority.HIGH, results, finishedLatch)); // will execute after the first HIGH (fifo) + awaitingLatch.countDown(); + finishedLatch.await(); + + assertThat(results.size(), equalTo(7)); + assertThat(results.get(0), equalTo(0)); + assertThat(results.get(1), equalTo(1)); + assertThat(results.get(2), equalTo(2)); + assertThat(results.get(3), equalTo(3)); + assertThat(results.get(4), equalTo(4)); + assertThat(results.get(5), equalTo(5)); + assertThat(results.get(6), equalTo(6)); + } + + @Test + public void testSubmitPrioritizedExecutorWithCallables() throws Exception { ExecutorService executor = EsExecutors.newSinglePrioritizingThreadExecutor(Executors.defaultThreadFactory()); List results = new ArrayList(7); CountDownLatch awaitingLatch = new CountDownLatch(1); @@ -111,7 +138,7 @@ public class PrioritizedExecutorsTests { } @Test - public void testPrioritizedExecutorWithMixed() throws Exception { + public void testSubmitPrioritizedExecutorWithMixed() throws Exception { ExecutorService executor = EsExecutors.newSinglePrioritizingThreadExecutor(Executors.defaultThreadFactory()); List results = new ArrayList(7); CountDownLatch awaitingLatch = new CountDownLatch(1);