HBASE-26708 Netty leak detected and OutOfDirectMemoryError due to direct memory buffering with SASL implementation (#4596)

Co-authored-by: Norman Maurer <norman_maurer@apple.com>
Signed-off-by: Andrew Purtell <apurtell@apache.org>
Signed-off-by: Viraj Jasani <vjasani@apache.org>
(cherry picked from commit 816e919e95)
This commit is contained in:
Duo Zhang 2022-07-07 15:55:26 +08:00
parent 1c90b4344d
commit b543da974a
5 changed files with 53 additions and 131 deletions

View File

@ -17,81 +17,32 @@
*/ */
package org.apache.hadoop.hbase.security; package org.apache.hadoop.hbase.security;
import org.apache.hadoop.hbase.exceptions.ConnectionClosedException;
import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
import org.apache.yetus.audience.InterfaceAudience; import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelOutboundHandlerAdapter; import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelPromise;
import org.apache.hbase.thirdparty.io.netty.channel.CoalescingBufferQueue;
import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
import org.apache.hbase.thirdparty.io.netty.util.concurrent.PromiseCombiner;
/** /**
* wrap messages with Crypto AES. * wrap messages with Crypto AES.
*/ */
@InterfaceAudience.Private @InterfaceAudience.Private
public class CryptoAESWrapHandler extends ChannelOutboundHandlerAdapter { public class CryptoAESWrapHandler extends MessageToByteEncoder<ByteBuf> {
private final CryptoAES cryptoAES; private final CryptoAES cryptoAES;
private CoalescingBufferQueue queue;
public CryptoAESWrapHandler(CryptoAES cryptoAES) { public CryptoAESWrapHandler(CryptoAES cryptoAES) {
this.cryptoAES = cryptoAES; this.cryptoAES = cryptoAES;
} }
@Override @Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception { protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
queue = new CoalescingBufferQueue(ctx.channel()); byte[] bytes = new byte[msg.readableBytes()];
} msg.readBytes(bytes);
byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length);
@Override out.ensureWritable(4 + wrapperBytes.length);
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) out.writeInt(wrapperBytes.length);
throws Exception { out.writeBytes(wrapperBytes);
if (msg instanceof ByteBuf) {
queue.add((ByteBuf) msg, promise);
} else {
ctx.write(msg, promise);
}
}
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
if (queue.isEmpty()) {
return;
}
ByteBuf buf = null;
try {
ChannelPromise promise = ctx.newPromise();
int readableBytes = queue.readableBytes();
buf = queue.remove(readableBytes, promise);
byte[] bytes = new byte[readableBytes];
buf.readBytes(bytes);
byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length);
ChannelPromise lenPromise = ctx.newPromise();
ctx.write(ctx.alloc().buffer(4).writeInt(wrapperBytes.length), lenPromise);
ChannelPromise contentPromise = ctx.newPromise();
ctx.write(Unpooled.wrappedBuffer(wrapperBytes), contentPromise);
PromiseCombiner combiner = new PromiseCombiner();
combiner.addAll(lenPromise, contentPromise);
combiner.finish(promise);
ctx.flush();
} finally {
if (buf != null) {
ReferenceCountUtil.safeRelease(buf);
}
}
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
if (!queue.isEmpty()) {
queue.releaseAndFailAll(new ConnectionClosedException("Connection closed"));
}
ctx.close(promise);
} }
} }

View File

@ -18,80 +18,31 @@
package org.apache.hadoop.hbase.security; package org.apache.hadoop.hbase.security;
import javax.security.sasl.SaslClient; import javax.security.sasl.SaslClient;
import org.apache.hadoop.hbase.exceptions.ConnectionClosedException;
import org.apache.yetus.audience.InterfaceAudience; import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelOutboundHandlerAdapter; import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelPromise;
import org.apache.hbase.thirdparty.io.netty.channel.CoalescingBufferQueue;
import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
import org.apache.hbase.thirdparty.io.netty.util.concurrent.PromiseCombiner;
/** /**
* wrap sasl messages. * wrap sasl messages.
*/ */
@InterfaceAudience.Private @InterfaceAudience.Private
public class SaslWrapHandler extends ChannelOutboundHandlerAdapter { public class SaslWrapHandler extends MessageToByteEncoder<ByteBuf> {
private final SaslClient saslClient; private final SaslClient saslClient;
private CoalescingBufferQueue queue;
public SaslWrapHandler(SaslClient saslClient) { public SaslWrapHandler(SaslClient saslClient) {
this.saslClient = saslClient; this.saslClient = saslClient;
} }
@Override @Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception { protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
queue = new CoalescingBufferQueue(ctx.channel()); byte[] bytes = new byte[msg.readableBytes()];
} msg.readBytes(bytes);
byte[] wrapperBytes = saslClient.wrap(bytes, 0, bytes.length);
@Override out.ensureWritable(4 + wrapperBytes.length);
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) out.writeInt(wrapperBytes.length);
throws Exception { out.writeBytes(wrapperBytes);
if (msg instanceof ByteBuf) {
queue.add((ByteBuf) msg, promise);
} else {
ctx.write(msg, promise);
}
}
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
if (queue.isEmpty()) {
return;
}
ByteBuf buf = null;
try {
ChannelPromise promise = ctx.newPromise();
int readableBytes = queue.readableBytes();
buf = queue.remove(readableBytes, promise);
byte[] bytes = new byte[readableBytes];
buf.readBytes(bytes);
byte[] wrapperBytes = saslClient.wrap(bytes, 0, bytes.length);
ChannelPromise lenPromise = ctx.newPromise();
ctx.write(ctx.alloc().buffer(4).writeInt(wrapperBytes.length), lenPromise);
ChannelPromise contentPromise = ctx.newPromise();
ctx.write(Unpooled.wrappedBuffer(wrapperBytes), contentPromise);
PromiseCombiner combiner = new PromiseCombiner();
combiner.addAll(lenPromise, contentPromise);
combiner.finish(promise);
ctx.flush();
} finally {
if (buf != null) {
ReferenceCountUtil.safeRelease(buf);
}
}
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
if (!queue.isEmpty()) {
queue.releaseAndFailAll(new ConnectionClosedException("Connection closed"));
}
ctx.close(promise);
} }
} }

View File

@ -33,6 +33,7 @@ import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescrip
import org.apache.hbase.thirdparty.com.google.protobuf.Message; import org.apache.hbase.thirdparty.com.google.protobuf.Message;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.channel.Channel; import org.apache.hbase.thirdparty.io.netty.channel.Channel;
import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader;
@ -60,12 +61,15 @@ class NettyServerRpcConnection extends ServerRpcConnection {
void process(final ByteBuf buf) throws IOException, InterruptedException { void process(final ByteBuf buf) throws IOException, InterruptedException {
if (connectionHeaderRead) { if (connectionHeaderRead) {
this.callCleanup = buf::release; this.callCleanup = () -> ReferenceCountUtil.safeRelease(buf);
process(new SingleByteBuff(buf.nioBuffer())); process(new SingleByteBuff(buf.nioBuffer()));
} else { } else {
ByteBuffer connectionHeader = ByteBuffer.allocate(buf.readableBytes()); ByteBuffer connectionHeader = ByteBuffer.allocate(buf.readableBytes());
buf.readBytes(connectionHeader); try {
buf.release(); buf.readBytes(connectionHeader);
} finally {
buf.release();
}
process(connectionHeader); process(connectionHeader);
} }
} }
@ -78,9 +82,7 @@ class NettyServerRpcConnection extends ServerRpcConnection {
try { try {
if (skipInitialSaslHandshake) { if (skipInitialSaslHandshake) {
skipInitialSaslHandshake = false; skipInitialSaslHandshake = false;
if (callCleanup != null) { callCleanupIfNeeded();
callCleanup.run();
}
return; return;
} }
@ -90,9 +92,7 @@ class NettyServerRpcConnection extends ServerRpcConnection {
processOneRpc(buf); processOneRpc(buf);
} }
} catch (Exception e) { } catch (Exception e) {
if (callCleanup != null) { callCleanupIfNeeded();
callCleanup.run();
}
throw e; throw e;
} finally { } finally {
this.callCleanup = null; this.callCleanup = null;
@ -103,7 +103,7 @@ class NettyServerRpcConnection extends ServerRpcConnection {
public synchronized void close() { public synchronized void close() {
disposeSasl(); disposeSasl();
channel.close(); channel.close();
callCleanup = null; callCleanupIfNeeded();
} }
@Override @Override

View File

@ -342,6 +342,8 @@ abstract class ServerRpcConnection implements Closeable {
} else { } else {
plaintextData = saslServer.unwrap(b, 0, b.length); plaintextData = saslServer.unwrap(b, 0, b.length);
} }
// release the request buffer as we have already unwrapped all its content
callCleanupIfNeeded();
processUnwrappedData(plaintextData); processUnwrappedData(plaintextData);
} }
} else { } else {
@ -383,6 +385,9 @@ abstract class ServerRpcConnection implements Closeable {
RpcServer.AUDITLOG.warn("{} {}: {}", RpcServer.AUTH_FAILED_FOR, clientIP, RpcServer.AUDITLOG.warn("{} {}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
saslServer.getAttemptingUser()); saslServer.getAttemptingUser());
throw e; throw e;
} finally {
// release the request buffer as we have already unwrapped all its content
callCleanupIfNeeded();
} }
if (replyToken != null) { if (replyToken != null) {
if (RpcServer.LOG.isDebugEnabled()) { if (RpcServer.LOG.isDebugEnabled()) {
@ -412,7 +417,9 @@ abstract class ServerRpcConnection implements Closeable {
int count; int count;
if (unwrappedDataLengthBuffer.remaining() > 0) { if (unwrappedDataLengthBuffer.remaining() > 0) {
count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer); count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer);
if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) return; if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) {
return;
}
} }
if (unwrappedData == null) { if (unwrappedData == null) {
@ -428,7 +435,9 @@ abstract class ServerRpcConnection implements Closeable {
} }
count = this.rpcServer.channelRead(ch, unwrappedData); count = this.rpcServer.channelRead(ch, unwrappedData);
if (count <= 0 || unwrappedData.remaining() > 0) return; if (count <= 0 || unwrappedData.remaining() > 0) {
return;
}
if (unwrappedData.remaining() == 0) { if (unwrappedData.remaining() == 0) {
unwrappedDataLengthBuffer.clear(); unwrappedDataLengthBuffer.clear();
@ -732,6 +741,13 @@ abstract class ServerRpcConnection implements Closeable {
doRespond(getErrorResponse(msg, e)); doRespond(getErrorResponse(msg, e));
} }
protected final void callCleanupIfNeeded() {
if (callCleanup != null) {
callCleanup.run();
callCleanup = null;
}
}
protected final boolean processPreamble(ByteBuffer preambleBuffer) throws IOException { protected final boolean processPreamble(ByteBuffer preambleBuffer) throws IOException {
assert preambleBuffer.remaining() == 6; assert preambleBuffer.remaining() == 6;
for (int i = 0; i < RPC_HEADER.length; i++) { for (int i = 0; i < RPC_HEADER.length; i++) {

View File

@ -284,7 +284,9 @@ class SimpleServerRpcConnection extends ServerRpcConnection {
} else { } else {
processOneRpc(data); processOneRpc(data);
} }
} catch (Exception e) {
callCleanupIfNeeded();
throw e;
} finally { } finally {
dataLengthBuffer.clear(); // Clean for the next call dataLengthBuffer.clear(); // Clean for the next call
data = null; // For the GC data = null; // For the GC
@ -296,8 +298,10 @@ class SimpleServerRpcConnection extends ServerRpcConnection {
public synchronized void close() { public synchronized void close() {
disposeSasl(); disposeSasl();
data = null; data = null;
callCleanup = null; callCleanupIfNeeded();
if (!channel.isOpen()) return; if (!channel.isOpen()) {
return;
}
try { try {
socket.shutdownOutput(); socket.shutdownOutput();
} catch (Exception ignored) { } catch (Exception ignored) {