diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java
index c4c914a04d8..a99d097ff2d 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java
@@ -17,81 +17,32 @@
  */
 package org.apache.hadoop.hbase.security;
 
-import org.apache.hadoop.hbase.exceptions.ConnectionClosedException;
 import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
 import org.apache.yetus.audience.InterfaceAudience;
 
 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
-import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
-import org.apache.hbase.thirdparty.io.netty.channel.ChannelOutboundHandlerAdapter;
-import org.apache.hbase.thirdparty.io.netty.channel.ChannelPromise;
-import org.apache.hbase.thirdparty.io.netty.channel.CoalescingBufferQueue;
-import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
-import org.apache.hbase.thirdparty.io.netty.util.concurrent.PromiseCombiner;
+import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
 
 /**
  * wrap messages with Crypto AES.
  */
 @InterfaceAudience.Private
-public class CryptoAESWrapHandler extends ChannelOutboundHandlerAdapter {
+public class CryptoAESWrapHandler extends MessageToByteEncoder<ByteBuf> {
 
   private final CryptoAES cryptoAES;
 
-  private CoalescingBufferQueue queue;
-
   public CryptoAESWrapHandler(CryptoAES cryptoAES) {
     this.cryptoAES = cryptoAES;
   }
 
   @Override
-  public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
-    queue = new CoalescingBufferQueue(ctx.channel());
-  }
-
-  @Override
-  public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
-      throws Exception {
-    if (msg instanceof ByteBuf) {
-      queue.add((ByteBuf) msg, promise);
-    } else {
-      ctx.write(msg, promise);
-    }
-  }
-
-  @Override
-  public void flush(ChannelHandlerContext ctx) throws Exception {
-    if (queue.isEmpty()) {
-      return;
-    }
-    ByteBuf buf = null;
-    try {
-      ChannelPromise promise = ctx.newPromise();
-      int readableBytes = queue.readableBytes();
-      buf = queue.remove(readableBytes, promise);
-      byte[] bytes = new byte[readableBytes];
-      buf.readBytes(bytes);
-      byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length);
-      ChannelPromise lenPromise = ctx.newPromise();
-      ctx.write(ctx.alloc().buffer(4).writeInt(wrapperBytes.length), lenPromise);
-      ChannelPromise contentPromise = ctx.newPromise();
-      ctx.write(Unpooled.wrappedBuffer(wrapperBytes), contentPromise);
-      PromiseCombiner combiner = new PromiseCombiner();
-      combiner.addAll(lenPromise, contentPromise);
-      combiner.finish(promise);
-      ctx.flush();
-    } finally {
-      if (buf != null) {
-        ReferenceCountUtil.safeRelease(buf);
-      }
-    }
-  }
-
-  @Override
-  public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
-    if (!queue.isEmpty()) {
-      queue.releaseAndFailAll(new ConnectionClosedException("Connection closed"));
-    }
-    ctx.close(promise);
+  protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
+    byte[] bytes = new byte[msg.readableBytes()];
+    msg.readBytes(bytes);
+    byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length);
+    out.ensureWritable(4 + wrapperBytes.length);
+    out.writeInt(wrapperBytes.length);
+    out.writeBytes(wrapperBytes);
   }
 }
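
Reviewer note: the rewritten encode() keeps the existing wire format, a 4-byte length prefix followed by the wrapped payload, so the peer's unwrap side needs no change. As a hedged illustration (not part of this patch; the helper class name and the 1 MiB frame cap are made-up assumptions), a peer or test pipeline could strip such frames with Netty's stock LengthFieldBasedFrameDecoder:

// Hedged sketch, not part of this patch: decoding the
// <4-byte length><wrapped bytes> frames emitted by the handlers above.
import org.apache.hbase.thirdparty.io.netty.channel.Channel;
import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder;

final class WrapFrameDecoding {
  static void addFrameDecoder(Channel ch) {
    // length field at offset 0, 4 bytes wide; strip those 4 bytes before
    // handing the wrapped payload to the unwrap side
    ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(1024 * 1024, 0, 4, 0, 4));
  }
}
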
diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java
index ebc32a827aa..21f70e3f1e4 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java
@@ -18,80 +18,31 @@
 package org.apache.hadoop.hbase.security;
 
 import javax.security.sasl.SaslClient;
-import org.apache.hadoop.hbase.exceptions.ConnectionClosedException;
 import org.apache.yetus.audience.InterfaceAudience;
 
 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
-import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
-import org.apache.hbase.thirdparty.io.netty.channel.ChannelOutboundHandlerAdapter;
-import org.apache.hbase.thirdparty.io.netty.channel.ChannelPromise;
-import org.apache.hbase.thirdparty.io.netty.channel.CoalescingBufferQueue;
-import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
-import org.apache.hbase.thirdparty.io.netty.util.concurrent.PromiseCombiner;
+import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
 
 /**
  * wrap sasl messages.
  */
 @InterfaceAudience.Private
-public class SaslWrapHandler extends ChannelOutboundHandlerAdapter {
+public class SaslWrapHandler extends MessageToByteEncoder<ByteBuf> {
 
   private final SaslClient saslClient;
 
-  private CoalescingBufferQueue queue;
-
   public SaslWrapHandler(SaslClient saslClient) {
     this.saslClient = saslClient;
   }
 
   @Override
-  public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
-    queue = new CoalescingBufferQueue(ctx.channel());
-  }
-
-  @Override
-  public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
-      throws Exception {
-    if (msg instanceof ByteBuf) {
-      queue.add((ByteBuf) msg, promise);
-    } else {
-      ctx.write(msg, promise);
-    }
-  }
-
-  @Override
-  public void flush(ChannelHandlerContext ctx) throws Exception {
-    if (queue.isEmpty()) {
-      return;
-    }
-    ByteBuf buf = null;
-    try {
-      ChannelPromise promise = ctx.newPromise();
-      int readableBytes = queue.readableBytes();
-      buf = queue.remove(readableBytes, promise);
-      byte[] bytes = new byte[readableBytes];
-      buf.readBytes(bytes);
-      byte[] wrapperBytes = saslClient.wrap(bytes, 0, bytes.length);
-      ChannelPromise lenPromise = ctx.newPromise();
-      ctx.write(ctx.alloc().buffer(4).writeInt(wrapperBytes.length), lenPromise);
-      ChannelPromise contentPromise = ctx.newPromise();
-      ctx.write(Unpooled.wrappedBuffer(wrapperBytes), contentPromise);
-      PromiseCombiner combiner = new PromiseCombiner();
-      combiner.addAll(lenPromise, contentPromise);
-      combiner.finish(promise);
-      ctx.flush();
-    } finally {
-      if (buf != null) {
-        ReferenceCountUtil.safeRelease(buf);
-      }
-    }
-  }
-
-  @Override
-  public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
-    if (!queue.isEmpty()) {
-      queue.releaseAndFailAll(new ConnectionClosedException("Connection closed"));
-    }
-    ctx.close(promise);
+  protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
+    byte[] bytes = new byte[msg.readableBytes()];
+    msg.readBytes(bytes);
+    byte[] wrapperBytes = saslClient.wrap(bytes, 0, bytes.length);
+    out.ensureWritable(4 + wrapperBytes.length);
+    out.writeInt(wrapperBytes.length);
+    out.writeBytes(wrapperBytes);
   }
 }
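
Reviewer note: in both handlers the CoalescingBufferQueue, the promise bookkeeping, and the explicit buffer release fall away because MessageToByteEncoder allocates the outbound buffer and releases the inbound ByteBuf itself after encode() returns. A minimal self-contained sketch of the pattern, with an identity wrap standing in for saslClient.wrap() / cryptoAES.wrap(), so this is illustrative only:

import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.embedded.EmbeddedChannel;
import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;

public final class WrapEncoderSketch {

  // Same shape as SaslWrapHandler/CryptoAESWrapHandler, but "wrap" is the
  // identity function here instead of a SASL or AES wrap call.
  static final class IdentityWrapEncoder extends MessageToByteEncoder<ByteBuf> {
    @Override
    protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) {
      byte[] bytes = new byte[msg.readableBytes()];
      msg.readBytes(bytes);
      byte[] wrapped = bytes; // stand-in for saslClient.wrap(...) / cryptoAES.wrap(...)
      out.ensureWritable(4 + wrapped.length);
      out.writeInt(wrapped.length);
      out.writeBytes(wrapped);
      // MessageToByteEncoder releases msg and writes out downstream itself,
      // so no CoalescingBufferQueue or PromiseCombiner bookkeeping is needed
    }
  }

  public static void main(String[] args) {
    EmbeddedChannel ch = new EmbeddedChannel(new IdentityWrapEncoder());
    ch.writeOutbound(Unpooled.wrappedBuffer(new byte[] { 1, 2, 3 }));
    ByteBuf framed = ch.readOutbound();
    System.out.println(framed.readInt()); // prints 3, the payload length prefix
    framed.release();
    ch.finish();
  }
}
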
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java
index 91468fdd039..53eff7e2ebb 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java
@@ -33,6 +33,7 @@ import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescrip
 import org.apache.hbase.thirdparty.com.google.protobuf.Message;
 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
 import org.apache.hbase.thirdparty.io.netty.channel.Channel;
+import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
 
 import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader;
 
@@ -60,12 +61,15 @@ class NettyServerRpcConnection extends ServerRpcConnection {
 
   void process(final ByteBuf buf) throws IOException, InterruptedException {
     if (connectionHeaderRead) {
-      this.callCleanup = buf::release;
+      this.callCleanup = () -> ReferenceCountUtil.safeRelease(buf);
       process(new SingleByteBuff(buf.nioBuffer()));
     } else {
       ByteBuffer connectionHeader = ByteBuffer.allocate(buf.readableBytes());
-      buf.readBytes(connectionHeader);
-      buf.release();
+      try {
+        buf.readBytes(connectionHeader);
+      } finally {
+        buf.release();
+      }
       process(connectionHeader);
     }
   }
@@ -78,9 +82,7 @@ class NettyServerRpcConnection extends ServerRpcConnection {
     try {
       if (skipInitialSaslHandshake) {
         skipInitialSaslHandshake = false;
-        if (callCleanup != null) {
-          callCleanup.run();
-        }
+        callCleanupIfNeeded();
         return;
       }
 
@@ -90,9 +92,7 @@ class NettyServerRpcConnection extends ServerRpcConnection {
         processOneRpc(buf);
       }
     } catch (Exception e) {
-      if (callCleanup != null) {
-        callCleanup.run();
-      }
+      callCleanupIfNeeded();
       throw e;
     } finally {
       this.callCleanup = null;
@@ -103,7 +103,7 @@
   public synchronized void close() {
     disposeSasl();
     channel.close();
-    callCleanup = null;
+    callCleanupIfNeeded();
   }
 
   @Override
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java
index f527d31a314..acafd7f7910 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java
@@ -342,6 +342,8 @@ abstract class ServerRpcConnection implements Closeable {
         } else {
           plaintextData = saslServer.unwrap(b, 0, b.length);
         }
+        // release the request buffer as we have already unwrapped all its content
+        callCleanupIfNeeded();
         processUnwrappedData(plaintextData);
       }
     } else {
@@ -383,6 +385,9 @@ abstract class ServerRpcConnection implements Closeable {
         RpcServer.AUDITLOG.warn("{} {}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
           saslServer.getAttemptingUser());
         throw e;
+      } finally {
+        // release the request buffer as we have already unwrapped all its content
+        callCleanupIfNeeded();
       }
       if (replyToken != null) {
         if (RpcServer.LOG.isDebugEnabled()) {
@@ -412,7 +417,9 @@ abstract class ServerRpcConnection implements Closeable {
       int count;
       if (unwrappedDataLengthBuffer.remaining() > 0) {
         count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer);
-        if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) return;
+        if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) {
+          return;
+        }
       }
 
       if (unwrappedData == null) {
@@ -428,7 +435,9 @@ abstract class ServerRpcConnection implements Closeable {
       }
       count =
           this.rpcServer.channelRead(ch, unwrappedData);
-      if (count <= 0 || unwrappedData.remaining() > 0) return;
+      if (count <= 0 || unwrappedData.remaining() > 0) {
+        return;
+      }
 
       if (unwrappedData.remaining() == 0) {
         unwrappedDataLengthBuffer.clear();
@@ -732,6 +741,13 @@ abstract class ServerRpcConnection implements Closeable {
     doRespond(getErrorResponse(msg, e));
   }
 
+  protected final void callCleanupIfNeeded() {
+    if (callCleanup != null) {
+      callCleanup.run();
+      callCleanup = null;
+    }
+  }
+
   protected final boolean processPreamble(ByteBuffer preambleBuffer) throws IOException {
     assert preambleBuffer.remaining() == 6;
     for (int i = 0; i < RPC_HEADER.length; i++) {
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java
index f59c002e6bb..ba7a9752a79 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java
@@ -284,7 +284,9 @@
       } else {
         processOneRpc(data);
       }
-
+    } catch (Exception e) {
+      callCleanupIfNeeded();
+      throw e;
     } finally {
       dataLengthBuffer.clear(); // Clean for the next call
       data = null; // For the GC
@@ -296,8 +298,10 @@
   public synchronized void close() {
     disposeSasl();
     data = null;
-    callCleanup = null;
-    if (!channel.isOpen()) return;
+    callCleanupIfNeeded();
+    if (!channel.isOpen()) {
+      return;
+    }
     try {
       socket.shutdownOutput();
     } catch (Exception ignored) {
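
Reviewer note: the new callCleanupIfNeeded() clears the callCleanup field after running it, so the buffer release is run-once. The error paths in process()/readAndProcess() and the later close() can both call it without double-releasing the retained request buffer, whereas close() previously just dropped the reference and leaked it. A hedged standalone sketch of the run-once shape (the class and names here are illustrative, not HBase API):

// Hedged sketch, not HBase code: why callCleanupIfNeeded() is safe to call
// from both the error path and close().
public final class RunOnceCleanupSketch {

  private Runnable callCleanup;

  RunOnceCleanupSketch(Runnable cleanup) {
    this.callCleanup = cleanup;
  }

  // Mirrors ServerRpcConnection.callCleanupIfNeeded(): running the cleanup
  // clears the field, so a second call is a no-op rather than a double
  // release of the underlying reference-counted buffer.
  void callCleanupIfNeeded() {
    if (callCleanup != null) {
      callCleanup.run();
      callCleanup = null;
    }
  }

  public static void main(String[] args) {
    RunOnceCleanupSketch conn =
      new RunOnceCleanupSketch(() -> System.out.println("buffer released"));
    conn.callCleanupIfNeeded(); // prints "buffer released"
    conn.callCleanupIfNeeded(); // no-op: cleanup already ran
  }
}
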