HADOOP-17749. Remove lock contention in SelectorPool of SocketIOWithTimeout (#3080)

(cherry picked from commit a5db6831bc)
This commit is contained in:
liangxs 2021-07-06 09:11:03 +08:00 committed by Hui Fei
parent e8f9af6f2a
commit 24b780820c
2 changed files with 124 additions and 58 deletions

View File

@ -28,8 +28,9 @@ import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.hadoop.util.Time;
import org.slf4j.Logger;
@ -48,8 +49,6 @@ abstract class SocketIOWithTimeout {
private long timeout;
private boolean closed = false;
private static SelectorPool selector = new SelectorPool();
/* A timeout value of 0 implies wait for ever.
* We should have a value of timeout that implies zero wait.. i.e.
* read or write returns immediately.
@ -154,7 +153,7 @@ abstract class SocketIOWithTimeout {
//now wait for socket to be ready.
int count = 0;
try {
count = selector.select(channel, ops, timeout);
count = SelectorPool.select(channel, ops, timeout);
} catch (IOException e) { //unexpected IOException.
closed = true;
throw e;
@ -200,7 +199,7 @@ abstract class SocketIOWithTimeout {
// we might have to call finishConnect() more than once
// for some channels (with user level protocols)
int ret = selector.select((SelectableChannel)channel,
int ret = SelectorPool.select(channel,
SelectionKey.OP_CONNECT, timeoutLeft);
if (ret > 0 && channel.finishConnect()) {
@ -242,7 +241,7 @@ abstract class SocketIOWithTimeout {
*/
void waitForIO(int ops) throws IOException {
if (selector.select(channel, ops, timeout) == 0) {
if (SelectorPool.select(channel, ops, timeout) == 0) {
throw new SocketTimeoutException(timeoutExceptionString(channel, timeout,
ops));
}
@ -280,12 +279,17 @@ abstract class SocketIOWithTimeout {
* This maintains a pool of selectors. These selectors are closed
* once they are idle (unused) for a few seconds.
*/
private static class SelectorPool {
private static final class SelectorPool {
private static class SelectorInfo {
Selector selector;
long lastActivityTime;
LinkedList<SelectorInfo> queue;
private static final class SelectorInfo {
private final SelectorProvider provider;
private final Selector selector;
private long lastActivityTime;
private SelectorInfo(SelectorProvider provider, Selector selector) {
this.provider = provider;
this.selector = selector;
}
void close() {
if (selector != null) {
@ -298,16 +302,11 @@ abstract class SocketIOWithTimeout {
}
}
private static class ProviderInfo {
SelectorProvider provider;
LinkedList<SelectorInfo> queue; // lifo
ProviderInfo next;
}
private static ConcurrentHashMap<SelectorProvider, ConcurrentLinkedDeque
<SelectorInfo>> providerMap = new ConcurrentHashMap<>();
private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds.
private ProviderInfo providerList = null;
/**
* Waits on the channel with the given timeout using one of the
* cached selectors. It also removes any cached selectors that are
@ -319,7 +318,7 @@ abstract class SocketIOWithTimeout {
* @return
* @throws IOException
*/
int select(SelectableChannel channel, int ops, long timeout)
static int select(SelectableChannel channel, int ops, long timeout)
throws IOException {
SelectorInfo info = get(channel);
@ -385,35 +384,18 @@ abstract class SocketIOWithTimeout {
* @return
* @throws IOException
*/
private synchronized SelectorInfo get(SelectableChannel channel)
private static SelectorInfo get(SelectableChannel channel)
throws IOException {
SelectorInfo selInfo = null;
SelectorProvider provider = channel.provider();
// pick the list : rarely there is more than one provider in use.
ProviderInfo pList = providerList;
while (pList != null && pList.provider != provider) {
pList = pList.next;
}
if (pList == null) {
//LOG.info("Creating new ProviderInfo : " + provider.toString());
pList = new ProviderInfo();
pList.provider = provider;
pList.queue = new LinkedList<SelectorInfo>();
pList.next = providerList;
providerList = pList;
}
LinkedList<SelectorInfo> queue = pList.queue;
if (queue.isEmpty()) {
ConcurrentLinkedDeque<SelectorInfo> infoQ = providerMap.computeIfAbsent(
provider, k -> new ConcurrentLinkedDeque<>());
SelectorInfo selInfo = infoQ.pollLast(); // last in first out
if (selInfo == null) {
Selector selector = provider.openSelector();
selInfo = new SelectorInfo();
selInfo.selector = selector;
selInfo.queue = queue;
} else {
selInfo = queue.removeLast();
// selInfo will be put into infoQ after `#release()`
selInfo = new SelectorInfo(provider, selector);
}
trimIdleSelectors(Time.now());
@ -426,34 +408,39 @@ abstract class SocketIOWithTimeout {
*
* @param info
*/
private synchronized void release(SelectorInfo info) {
private static void release(SelectorInfo info) {
long now = Time.now();
trimIdleSelectors(now);
info.lastActivityTime = now;
info.queue.addLast(info);
// SelectorInfos in queue are sorted by lastActivityTime
providerMap.get(info.provider).addLast(info);
}
private static AtomicBoolean trimming = new AtomicBoolean(false);
/**
* Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not
* traverse the whole list, just over the one that have crossed
* the timeout.
*/
private void trimIdleSelectors(long now) {
private static void trimIdleSelectors(long now) {
if (!trimming.compareAndSet(false, true)) {
return;
}
long cutoff = now - IDLE_TIMEOUT;
for(ProviderInfo pList=providerList; pList != null; pList=pList.next) {
if (pList.queue.isEmpty()) {
continue;
}
for(Iterator<SelectorInfo> it = pList.queue.iterator(); it.hasNext();) {
SelectorInfo info = it.next();
if (info.lastActivityTime > cutoff) {
for (ConcurrentLinkedDeque<SelectorInfo> infoQ : providerMap.values()) {
SelectorInfo oldest;
while ((oldest = infoQ.peekFirst()) != null) {
if (oldest.lastActivityTime <= cutoff && infoQ.remove(oldest)) {
oldest.close();
} else {
break;
}
it.remove();
info.close();
}
}
trimming.set(false);
}
}
}

View File

@ -24,6 +24,11 @@ import java.io.OutputStream;
import java.net.SocketTimeoutException;
import java.nio.channels.Pipe;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.test.MultithreadedTestUtil;
@ -186,6 +191,46 @@ public class TestSocketIOWithTimeout {
}
}
@Test
public void testSocketIOWithTimeoutByMultiThread() throws Exception {
CountDownLatch latch = new CountDownLatch(1);
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, TIMEOUT);
Pipe.SinkChannel sink = pipe.sink();
OutputStream out = new SocketOutputStream(sink, TIMEOUT)) {
byte[] writeBytes = TEST_STRING.getBytes();
byte[] readBytes = new byte[writeBytes.length];
latch.await();
out.write(writeBytes);
doIO(null, out, TIMEOUT);
in.read(readBytes);
assertArrayEquals(writeBytes, readBytes);
doIO(in, null, TIMEOUT);
}
} catch (Exception e) {
fail(e.getMessage());
}
};
int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}
Thread.sleep(1000);
latch.countDown();
threadPool.shutdown();
assertTrue(threadPool.awaitTermination(3, TimeUnit.SECONDS));
}
@Test
public void testSocketIOWithTimeoutInterrupted() throws Exception {
Pipe pipe = Pipe.open();
@ -223,4 +268,38 @@ public class TestSocketIOWithTimeout {
ctx.stop();
}
}
@Test
public void testSocketIOWithTimeoutInterruptedByMultiThread()
throws Exception {
final int timeout = TIMEOUT * 10;
AtomicLong readCount = new AtomicLong();
AtomicLong exceptionCount = new AtomicLong();
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, timeout)) {
in.read();
readCount.incrementAndGet();
} catch (InterruptedIOException ste) {
exceptionCount.incrementAndGet();
}
} catch (Exception e) {
fail(e.getMessage());
}
};
int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}
Thread.sleep(1000);
threadPool.shutdownNow();
threadPool.awaitTermination(1, TimeUnit.SECONDS);
assertEquals(0, readCount.get());
assertEquals(threadCnt, exceptionCount.get());
}
}