HADOOP-17749. Remove lock contention in SelectorPool of SocketIOWithTimeout (#3080)

(cherry picked from commit a5db6831bc)
This commit is contained in:
liangxs 2021-07-06 09:11:03 +08:00 committed by Hui Fei
parent e8f9af6f2a
commit 24b780820c
2 changed files with 124 additions and 58 deletions

View File

@ -28,8 +28,9 @@ import java.nio.channels.SelectionKey;
import java.nio.channels.Selector; import java.nio.channels.Selector;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider; import java.nio.channels.spi.SelectorProvider;
import java.util.Iterator; import java.util.concurrent.ConcurrentHashMap;
import java.util.LinkedList; import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.hadoop.util.Time; import org.apache.hadoop.util.Time;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -48,8 +49,6 @@ abstract class SocketIOWithTimeout {
private long timeout; private long timeout;
private boolean closed = false; private boolean closed = false;
private static SelectorPool selector = new SelectorPool();
/* A timeout value of 0 implies wait for ever. /* A timeout value of 0 implies wait for ever.
* We should have a value of timeout that implies zero wait.. i.e. * We should have a value of timeout that implies zero wait.. i.e.
* read or write returns immediately. * read or write returns immediately.
@ -154,7 +153,7 @@ abstract class SocketIOWithTimeout {
//now wait for socket to be ready. //now wait for socket to be ready.
int count = 0; int count = 0;
try { try {
count = selector.select(channel, ops, timeout); count = SelectorPool.select(channel, ops, timeout);
} catch (IOException e) { //unexpected IOException. } catch (IOException e) { //unexpected IOException.
closed = true; closed = true;
throw e; throw e;
@ -200,7 +199,7 @@ abstract class SocketIOWithTimeout {
// we might have to call finishConnect() more than once // we might have to call finishConnect() more than once
// for some channels (with user level protocols) // for some channels (with user level protocols)
int ret = selector.select((SelectableChannel)channel, int ret = SelectorPool.select(channel,
SelectionKey.OP_CONNECT, timeoutLeft); SelectionKey.OP_CONNECT, timeoutLeft);
if (ret > 0 && channel.finishConnect()) { if (ret > 0 && channel.finishConnect()) {
@ -242,7 +241,7 @@ abstract class SocketIOWithTimeout {
*/ */
void waitForIO(int ops) throws IOException { void waitForIO(int ops) throws IOException {
if (selector.select(channel, ops, timeout) == 0) { if (SelectorPool.select(channel, ops, timeout) == 0) {
throw new SocketTimeoutException(timeoutExceptionString(channel, timeout, throw new SocketTimeoutException(timeoutExceptionString(channel, timeout,
ops)); ops));
} }
@ -280,12 +279,17 @@ abstract class SocketIOWithTimeout {
* This maintains a pool of selectors. These selectors are closed * This maintains a pool of selectors. These selectors are closed
* once they are idle (unused) for a few seconds. * once they are idle (unused) for a few seconds.
*/ */
private static class SelectorPool { private static final class SelectorPool {
private static class SelectorInfo { private static final class SelectorInfo {
Selector selector; private final SelectorProvider provider;
long lastActivityTime; private final Selector selector;
LinkedList<SelectorInfo> queue; private long lastActivityTime;
private SelectorInfo(SelectorProvider provider, Selector selector) {
this.provider = provider;
this.selector = selector;
}
void close() { void close() {
if (selector != null) { if (selector != null) {
@ -298,16 +302,11 @@ abstract class SocketIOWithTimeout {
} }
} }
private static class ProviderInfo { private static ConcurrentHashMap<SelectorProvider, ConcurrentLinkedDeque
SelectorProvider provider; <SelectorInfo>> providerMap = new ConcurrentHashMap<>();
LinkedList<SelectorInfo> queue; // lifo
ProviderInfo next;
}
private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds. private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds.
private ProviderInfo providerList = null;
/** /**
* Waits on the channel with the given timeout using one of the * Waits on the channel with the given timeout using one of the
* cached selectors. It also removes any cached selectors that are * cached selectors. It also removes any cached selectors that are
@ -319,7 +318,7 @@ abstract class SocketIOWithTimeout {
* @return * @return
* @throws IOException * @throws IOException
*/ */
int select(SelectableChannel channel, int ops, long timeout) static int select(SelectableChannel channel, int ops, long timeout)
throws IOException { throws IOException {
SelectorInfo info = get(channel); SelectorInfo info = get(channel);
@ -385,35 +384,18 @@ abstract class SocketIOWithTimeout {
* @return * @return
* @throws IOException * @throws IOException
*/ */
private synchronized SelectorInfo get(SelectableChannel channel) private static SelectorInfo get(SelectableChannel channel)
throws IOException { throws IOException {
SelectorInfo selInfo = null;
SelectorProvider provider = channel.provider(); SelectorProvider provider = channel.provider();
// pick the list : rarely there is more than one provider in use. // pick the list : rarely there is more than one provider in use.
ProviderInfo pList = providerList; ConcurrentLinkedDeque<SelectorInfo> infoQ = providerMap.computeIfAbsent(
while (pList != null && pList.provider != provider) { provider, k -> new ConcurrentLinkedDeque<>());
pList = pList.next;
}
if (pList == null) {
//LOG.info("Creating new ProviderInfo : " + provider.toString());
pList = new ProviderInfo();
pList.provider = provider;
pList.queue = new LinkedList<SelectorInfo>();
pList.next = providerList;
providerList = pList;
}
LinkedList<SelectorInfo> queue = pList.queue; SelectorInfo selInfo = infoQ.pollLast(); // last in first out
if (selInfo == null) {
if (queue.isEmpty()) {
Selector selector = provider.openSelector(); Selector selector = provider.openSelector();
selInfo = new SelectorInfo(); // selInfo will be put into infoQ after `#release()`
selInfo.selector = selector; selInfo = new SelectorInfo(provider, selector);
selInfo.queue = queue;
} else {
selInfo = queue.removeLast();
} }
trimIdleSelectors(Time.now()); trimIdleSelectors(Time.now());
@ -426,34 +408,39 @@ abstract class SocketIOWithTimeout {
* *
* @param info * @param info
*/ */
private synchronized void release(SelectorInfo info) { private static void release(SelectorInfo info) {
long now = Time.now(); long now = Time.now();
trimIdleSelectors(now); trimIdleSelectors(now);
info.lastActivityTime = now; info.lastActivityTime = now;
info.queue.addLast(info); // SelectorInfos in queue are sorted by lastActivityTime
providerMap.get(info.provider).addLast(info);
} }
private static AtomicBoolean trimming = new AtomicBoolean(false);
/** /**
* Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not * Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not
* traverse the whole list, just over the one that have crossed * traverse the whole list, just over the one that have crossed
* the timeout. * the timeout.
*/ */
private void trimIdleSelectors(long now) { private static void trimIdleSelectors(long now) {
long cutoff = now - IDLE_TIMEOUT; if (!trimming.compareAndSet(false, true)) {
return;
for(ProviderInfo pList=providerList; pList != null; pList=pList.next) {
if (pList.queue.isEmpty()) {
continue;
} }
for(Iterator<SelectorInfo> it = pList.queue.iterator(); it.hasNext();) {
SelectorInfo info = it.next(); long cutoff = now - IDLE_TIMEOUT;
if (info.lastActivityTime > cutoff) { for (ConcurrentLinkedDeque<SelectorInfo> infoQ : providerMap.values()) {
SelectorInfo oldest;
while ((oldest = infoQ.peekFirst()) != null) {
if (oldest.lastActivityTime <= cutoff && infoQ.remove(oldest)) {
oldest.close();
} else {
break; break;
} }
it.remove();
info.close();
} }
} }
trimming.set(false);
} }
} }
} }

View File

@ -24,6 +24,11 @@ import java.io.OutputStream;
import java.net.SocketTimeoutException; import java.net.SocketTimeoutException;
import java.nio.channels.Pipe; import java.nio.channels.Pipe;
import java.util.Arrays; import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.hadoop.test.GenericTestUtils; import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.test.MultithreadedTestUtil; import org.apache.hadoop.test.MultithreadedTestUtil;
@ -186,6 +191,46 @@ public class TestSocketIOWithTimeout {
} }
} }
@Test
public void testSocketIOWithTimeoutByMultiThread() throws Exception {
CountDownLatch latch = new CountDownLatch(1);
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, TIMEOUT);
Pipe.SinkChannel sink = pipe.sink();
OutputStream out = new SocketOutputStream(sink, TIMEOUT)) {
byte[] writeBytes = TEST_STRING.getBytes();
byte[] readBytes = new byte[writeBytes.length];
latch.await();
out.write(writeBytes);
doIO(null, out, TIMEOUT);
in.read(readBytes);
assertArrayEquals(writeBytes, readBytes);
doIO(in, null, TIMEOUT);
}
} catch (Exception e) {
fail(e.getMessage());
}
};
int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}
Thread.sleep(1000);
latch.countDown();
threadPool.shutdown();
assertTrue(threadPool.awaitTermination(3, TimeUnit.SECONDS));
}
@Test @Test
public void testSocketIOWithTimeoutInterrupted() throws Exception { public void testSocketIOWithTimeoutInterrupted() throws Exception {
Pipe pipe = Pipe.open(); Pipe pipe = Pipe.open();
@ -223,4 +268,38 @@ public class TestSocketIOWithTimeout {
ctx.stop(); ctx.stop();
} }
} }
@Test
public void testSocketIOWithTimeoutInterruptedByMultiThread()
throws Exception {
final int timeout = TIMEOUT * 10;
AtomicLong readCount = new AtomicLong();
AtomicLong exceptionCount = new AtomicLong();
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, timeout)) {
in.read();
readCount.incrementAndGet();
} catch (InterruptedIOException ste) {
exceptionCount.incrementAndGet();
}
} catch (Exception e) {
fail(e.getMessage());
}
};
int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}
Thread.sleep(1000);
threadPool.shutdownNow();
threadPool.awaitTermination(1, TimeUnit.SECONDS);
assertEquals(0, readCount.get());
assertEquals(threadCnt, exceptionCount.get());
}
} }