diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/util/BoundedConcurrentLinkedQueue.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/util/BoundedConcurrentLinkedQueue.java index 92082388890..f66771bfa56 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/util/BoundedConcurrentLinkedQueue.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/util/BoundedConcurrentLinkedQueue.java @@ -19,6 +19,7 @@ package org.apache.hadoop.hbase.util; import java.util.Collection; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLong; import org.apache.hadoop.hbase.classification.InterfaceAudience; import org.apache.hadoop.hbase.classification.InterfaceStability; @@ -30,7 +31,7 @@ import org.apache.hadoop.hbase.classification.InterfaceStability; @InterfaceStability.Stable public class BoundedConcurrentLinkedQueue extends ConcurrentLinkedQueue { private static final long serialVersionUID = 1L; - private volatile long size = 0; + private final AtomicLong size = new AtomicLong(0L); private final long maxSize; public BoundedConcurrentLinkedQueue() { @@ -42,41 +43,50 @@ public class BoundedConcurrentLinkedQueue extends ConcurrentLinkedQueue { this.maxSize = maxSize; } - @Override - public boolean add(T e) { - return offer(e); - } - @Override public boolean addAll(Collection c) { - size += c.size(); // Between here and below we might reject offers, - if (size > maxSize) { // if over maxSize, but that's ok - size -= c.size(); // We're over, just back out and return. - return false; + for (;;) { + long currentSize = size.get(); + long nextSize = currentSize + c.size(); + if (nextSize > maxSize) { // already exceeded limit + return false; + } + if (size.compareAndSet(currentSize, nextSize)) { + break; + } } - return super.addAll(c); // Always true for ConcurrentLinkedQueue + return super.addAll(c); // Always true for ConcurrentLinkedQueue } @Override public void clear() { - super.clear(); - size = 0; + // override this method to batch update size. + long removed = 0L; + while (super.poll() != null) { + removed++; + } + size.addAndGet(-removed); } @Override public boolean offer(T e) { - if (++size > maxSize) { - --size; // We didn't take it after all - return false; + for (;;) { + long currentSize = size.get(); + if (currentSize >= maxSize) { // already exceeded limit + return false; + } + if (size.compareAndSet(currentSize, currentSize + 1)) { + break; + } } - return super.offer(e); // Always true for ConcurrentLinkedQueue + return super.offer(e); // Always true for ConcurrentLinkedQueue } @Override public T poll() { T result = super.poll(); if (result != null) { - --size; + size.decrementAndGet(); } return result; } @@ -85,30 +95,28 @@ public class BoundedConcurrentLinkedQueue extends ConcurrentLinkedQueue { public boolean remove(Object o) { boolean result = super.remove(o); if (result) { - --size; + size.decrementAndGet(); } return result; } @Override public int size() { - return (int) size; + return (int) size.get(); } public void drainTo(Collection list) { long removed = 0; - T l; - while ((l = super.poll()) != null) { - list.add(l); + for (T element; (element = super.poll()) != null;) { + list.add(element); removed++; } - // Limit the number of operations on a volatile by only reporting size - // change after the drain is completed. - size -= removed; + // Limit the number of operations on size by only reporting size change after the drain is + // completed. + size.addAndGet(-removed); } public long remainingCapacity() { - long remaining = maxSize - size; - return remaining >= 0 ? remaining : 0; + return maxSize - size.get(); } } \ No newline at end of file diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/util/TestBoundedConcurrentLinkedQueue.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/util/TestBoundedConcurrentLinkedQueue.java index 5972a8764f9..d3c206f681a 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/util/TestBoundedConcurrentLinkedQueue.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/util/TestBoundedConcurrentLinkedQueue.java @@ -18,14 +18,15 @@ package org.apache.hadoop.hbase.util; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import java.util.ArrayList; import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.hadoop.hbase.testclassification.SmallTests; -import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -41,10 +42,6 @@ public class TestBoundedConcurrentLinkedQueue { this.queue = new BoundedConcurrentLinkedQueue(CAPACITY); } - @After - public void tearDown() throws Exception { - } - @Test public void testOfferAndPoll() throws Exception { // Offer @@ -82,4 +79,82 @@ public class TestBoundedConcurrentLinkedQueue { assertEquals(0, queue.size()); assertEquals(CAPACITY, queue.remainingCapacity()); } + + @Test + public void testClear() { + // Offer + for (long i = 1; i <= CAPACITY; ++i) { + assertTrue(queue.offer(i)); + assertEquals(i, queue.size()); + assertEquals(CAPACITY - i, queue.remainingCapacity()); + } + assertFalse(queue.offer(0L)); + + queue.clear(); + assertEquals(null, queue.poll()); + assertEquals(0, queue.size()); + assertEquals(CAPACITY, queue.remainingCapacity()); + } + + @Test + public void testMultiThread() throws InterruptedException { + int offerThreadCount = 10; + int pollThreadCount = 5; + int duration = 5000; // ms + final AtomicBoolean stop = new AtomicBoolean(false); + Thread[] offerThreads = new Thread[offerThreadCount]; + for (int i = 0; i < offerThreadCount; i++) { + offerThreads[i] = new Thread("offer-thread-" + i) { + + @Override + public void run() { + Random rand = new Random(); + while (!stop.get()) { + queue.offer(rand.nextLong()); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + } + } + } + + }; + } + Thread[] pollThreads = new Thread[pollThreadCount]; + for (int i = 0; i < pollThreadCount; i++) { + pollThreads[i] = new Thread("poll-thread-" + i) { + + @Override + public void run() { + while (!stop.get()) { + queue.poll(); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + } + } + } + + }; + } + for (Thread t : offerThreads) { + t.start(); + } + for (Thread t : pollThreads) { + t.start(); + } + long startTime = System.currentTimeMillis(); + while (System.currentTimeMillis() - startTime < duration) { + assertTrue(queue.size() <= CAPACITY); + Thread.yield(); + } + stop.set(true); + for (Thread t : offerThreads) { + t.join(); + } + for (Thread t : pollThreads) { + t.join(); + } + assertTrue(queue.size() <= CAPACITY); + } }