From 0c4263a18be13f412fcf71ebba9715dcc5ca2a12 Mon Sep 17 00:00:00 2001 From: Duo Zhang Date: Wed, 27 Jul 2022 09:00:42 +0800 Subject: [PATCH] HBASE-27185 Rewrite NettyRpcServer to decode rpc request with netty handler (#4624) Signed-off-by: Xin Sun --- .../hadoop/hbase/ipc/NettyRpcConnection.java | 3 - .../security/CryptoAESUnwrapHandler.java | 47 --- .../hbase/security/CryptoAESWrapHandler.java | 48 --- .../NettyHBaseRpcConnectionHeaderHandler.java | 8 +- .../security/NettyHBaseSaslRpcClient.java | 4 +- .../NettyHBaseSaslRpcClientHandler.java | 14 +- .../hbase/security/SaslUnwrapHandler.java | 18 +- .../hbase/security/SaslWrapHandler.java | 14 +- .../ipc/NettyHBaseSaslRpcServerHandler.java | 115 +++++++ .../hbase/ipc/NettyRpcFrameDecoder.java | 21 +- .../hadoop/hbase/ipc/NettyRpcServer.java | 9 +- .../ipc/NettyRpcServerPreambleHandler.java | 27 +- .../ipc/NettyRpcServerRequestDecoder.java | 56 +-- .../hbase/ipc/NettyServerRpcConnection.java | 56 ++- .../apache/hadoop/hbase/ipc/ServerCall.java | 47 +-- .../hadoop/hbase/ipc/ServerRpcConnection.java | 319 +++++++----------- .../hbase/ipc/SimpleRpcServerResponder.java | 28 ++ .../hbase/ipc/SimpleServerRpcConnection.java | 115 +++++++ .../hbase/security/HBaseSaslRpcServer.java | 32 +- 19 files changed, 511 insertions(+), 470 deletions(-) delete mode 100644 hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java delete mode 100644 hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java create mode 100644 hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java index 14e8cbc13d3..d211b1b98e5 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java @@ -223,9 +223,6 @@ class NettyRpcConnection extends RpcConnection { public void operationComplete(Future future) throws Exception { if (future.isSuccess()) { ChannelPipeline p = ch.pipeline(); - p.remove(SaslChallengeDecoder.class); - p.remove(NettyHBaseSaslRpcClientHandler.class); - // check if negotiate with server for connection header is necessary if (saslHandler.isNeedProcessConnectionHeader()) { Promise connectionHeaderPromise = ch.eventLoop().newPromise(); diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java deleted file mode 100644 index 31ed191f91a..00000000000 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hadoop.hbase.security; - -import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; -import org.apache.yetus.audience.InterfaceAudience; - -import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; -import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled; -import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; -import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; - -/** - * Unwrap messages with Crypto AES. Should be placed after a - * io.netty.handler.codec.LengthFieldBasedFrameDecoder - */ -@InterfaceAudience.Private -public class CryptoAESUnwrapHandler extends SimpleChannelInboundHandler { - - private final CryptoAES cryptoAES; - - public CryptoAESUnwrapHandler(CryptoAES cryptoAES) { - this.cryptoAES = cryptoAES; - } - - @Override - protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { - byte[] bytes = new byte[msg.readableBytes()]; - msg.readBytes(bytes); - ctx.fireChannelRead(Unpooled.wrappedBuffer(cryptoAES.unwrap(bytes, 0, bytes.length))); - } -} diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java deleted file mode 100644 index a99d097ff2d..00000000000 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hadoop.hbase.security; - -import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; -import org.apache.yetus.audience.InterfaceAudience; - -import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; -import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; -import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder; - -/** - * wrap messages with Crypto AES. - */ -@InterfaceAudience.Private -public class CryptoAESWrapHandler extends MessageToByteEncoder { - - private final CryptoAES cryptoAES; - - public CryptoAESWrapHandler(CryptoAES cryptoAES) { - this.cryptoAES = cryptoAES; - } - - @Override - protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception { - byte[] bytes = new byte[msg.readableBytes()]; - msg.readBytes(bytes); - byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length); - out.ensureWritable(4 + wrapperBytes.length); - out.writeInt(wrapperBytes.length); - out.writeBytes(wrapperBytes); - } -} diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java index a75091c5293..20197912dcb 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java @@ -25,7 +25,6 @@ import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline; import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; -import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder; import org.apache.hbase.thirdparty.io.netty.util.concurrent.Promise; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; @@ -92,10 +91,7 @@ public class NettyHBaseRpcConnectionHeaderHandler extends SimpleChannelInboundHa * Remove handlers for sasl encryption and add handlers for Crypto AES encryption */ private void setupCryptoAESHandler(ChannelPipeline p, CryptoAES cryptoAES) { - p.remove(SaslWrapHandler.class); - p.remove(SaslUnwrapHandler.class); - String lengthDecoder = p.context(LengthFieldBasedFrameDecoder.class).name(); - p.addAfter(lengthDecoder, null, new CryptoAESUnwrapHandler(cryptoAES)); - p.addAfter(lengthDecoder, null, new CryptoAESWrapHandler(cryptoAES)); + p.replace(SaslWrapHandler.class, null, new SaslWrapHandler(cryptoAES::wrap)); + p.replace(SaslUnwrapHandler.class, null, new SaslUnwrapHandler(cryptoAES::unwrap)); } } diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java index 9b16a41afe4..ede12258ad1 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java @@ -52,9 +52,9 @@ public class NettyHBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { return; } // add wrap and unwrap handlers to pipeline. - p.addFirst(new SaslWrapHandler(saslClient), + p.addFirst(new SaslWrapHandler(saslClient::wrap), new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4), - new SaslUnwrapHandler(saslClient)); + new SaslUnwrapHandler(saslClient::unwrap)); } public String getSaslQOP() { diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java index 7473c3269b0..d4d5cb39746 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.exceptions.ConnectionClosedException; import org.apache.hadoop.hbase.ipc.FallbackDisallowedException; import org.apache.hadoop.hbase.security.provider.SaslClientAuthenticationProvider; +import org.apache.hadoop.hbase.util.NettyFutureUtils; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; @@ -33,6 +34,7 @@ import org.slf4j.LoggerFactory; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; +import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline; import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; import org.apache.hbase.thirdparty.io.netty.util.concurrent.Promise; @@ -77,7 +79,7 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler< private void writeResponse(ChannelHandlerContext ctx, byte[] response) { LOG.trace("Sending token size={} from initSASLContext.", response.length); - ctx.writeAndFlush( + NettyFutureUtils.safeWriteAndFlush(ctx, ctx.alloc().buffer(4 + response.length).writeInt(response.length).writeBytes(response)); } @@ -90,8 +92,11 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler< if (LOG.isTraceEnabled()) { LOG.trace("SASL negotiation for {} is complete", provider.getSaslAuthMethod().getName()); } - + ChannelPipeline p = ctx.pipeline(); saslRpcClient.setupSaslHandler(ctx.pipeline()); + p.remove(SaslChallengeDecoder.class); + p.remove(this); + setCryptoAESOption(); saslPromise.setSuccess(true); @@ -110,6 +115,9 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler< @Override public void handlerAdded(ChannelHandlerContext ctx) { + // dispose the saslRpcClient when the channel is closed, since saslRpcClient is final, it is + // safe to reference it in lambda expr. + NettyFutureUtils.addListener(ctx.channel().closeFuture(), f -> saslRpcClient.dispose()); try { byte[] initialResponse = ugi.doAs(new PrivilegedExceptionAction() { @@ -170,14 +178,12 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler< @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { - saslRpcClient.dispose(); saslPromise.tryFailure(new ConnectionClosedException("Connection closed")); ctx.fireChannelInactive(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - saslRpcClient.dispose(); saslPromise.tryFailure(cause); } } diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUnwrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUnwrapHandler.java index dfc36e4ba31..87e518dae4a 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUnwrapHandler.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUnwrapHandler.java @@ -17,7 +17,7 @@ */ package org.apache.hadoop.hbase.security; -import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; import org.apache.yetus.audience.InterfaceAudience; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; @@ -32,22 +32,20 @@ import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; @InterfaceAudience.Private public class SaslUnwrapHandler extends SimpleChannelInboundHandler { - private final SaslClient saslClient; - - public SaslUnwrapHandler(SaslClient saslClient) { - this.saslClient = saslClient; + public interface Unwrapper { + byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException; } - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - SaslUtil.safeDispose(saslClient); - ctx.fireChannelInactive(); + private final Unwrapper unwrapper; + + public SaslUnwrapHandler(Unwrapper unwrapper) { + this.unwrapper = unwrapper; } @Override protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { byte[] bytes = new byte[msg.readableBytes()]; msg.readBytes(bytes); - ctx.fireChannelRead(Unpooled.wrappedBuffer(saslClient.unwrap(bytes, 0, bytes.length))); + ctx.fireChannelRead(Unpooled.wrappedBuffer(unwrapper.unwrap(bytes, 0, bytes.length))); } } diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java index 21f70e3f1e4..6caf2a3e8f5 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslWrapHandler.java @@ -17,7 +17,7 @@ */ package org.apache.hadoop.hbase.security; -import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; import org.apache.yetus.audience.InterfaceAudience; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; @@ -30,17 +30,21 @@ import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder; @InterfaceAudience.Private public class SaslWrapHandler extends MessageToByteEncoder { - private final SaslClient saslClient; + public interface Wrapper { + byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException; + } - public SaslWrapHandler(SaslClient saslClient) { - this.saslClient = saslClient; + private final Wrapper wrapper; + + public SaslWrapHandler(Wrapper wrapper) { + this.wrapper = wrapper; } @Override protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception { byte[] bytes = new byte[msg.readableBytes()]; msg.readBytes(bytes); - byte[] wrapperBytes = saslClient.wrap(bytes, 0, bytes.length); + byte[] wrapperBytes = wrapper.wrap(bytes, 0, bytes.length); out.ensureWritable(4 + wrapperBytes.length); out.writeInt(wrapperBytes.length); out.writeBytes(wrapperBytes); diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java new file mode 100644 index 00000000000..e36e0b44c74 --- /dev/null +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyHBaseSaslRpcServerHandler.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import java.io.IOException; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; +import org.apache.hadoop.hbase.security.SaslStatus; +import org.apache.hadoop.hbase.security.SaslUnwrapHandler; +import org.apache.hadoop.hbase.security.SaslWrapHandler; +import org.apache.hadoop.hbase.util.NettyFutureUtils; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; +import org.apache.hbase.thirdparty.io.netty.buffer.ByteBufOutputStream; +import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; +import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline; +import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; +import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +/** + * Implement SASL negotiation logic for rpc server. + */ +class NettyHBaseSaslRpcServerHandler extends SimpleChannelInboundHandler { + + private static final Logger LOG = LoggerFactory.getLogger(NettyHBaseSaslRpcServerHandler.class); + + static final String DECODER_NAME = "SaslNegotiationDecoder"; + + private final NettyRpcServer rpcServer; + + private final NettyServerRpcConnection conn; + + NettyHBaseSaslRpcServerHandler(NettyRpcServer rpcServer, NettyServerRpcConnection conn) { + this.rpcServer = rpcServer; + this.conn = conn; + } + + private void doResponse(ChannelHandlerContext ctx, SaslStatus status, Writable rv, + String errorClass, String error) throws IOException { + // In my testing, have noticed that sasl messages are usually + // in the ballpark of 100-200. That's why the initial capacity is 256. + ByteBuf resp = ctx.alloc().buffer(256); + try (ByteBufOutputStream out = new ByteBufOutputStream(resp)) { + out.writeInt(status.state); // write status + if (status == SaslStatus.SUCCESS) { + rv.write(out); + } else { + WritableUtils.writeString(out, errorClass); + WritableUtils.writeString(out, error); + } + } + NettyFutureUtils.safeWriteAndFlush(ctx, resp); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + LOG.debug("Read input token of size={} for processing by saslServer.evaluateResponse()", + msg.readableBytes()); + HBaseSaslRpcServer saslServer = conn.getOrCreateSaslServer(); + byte[] saslToken = new byte[msg.readableBytes()]; + msg.readBytes(saslToken, 0, saslToken.length); + byte[] replyToken = saslServer.evaluateResponse(saslToken); + if (replyToken != null) { + LOG.debug("Will send token of size {} from saslServer.", replyToken.length); + doResponse(ctx, SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null); + } + if (saslServer.isComplete()) { + conn.finishSaslNegotiation(); + String qop = saslServer.getNegotiatedQop(); + boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + ChannelPipeline p = ctx.pipeline(); + if (useWrap) { + p.addFirst(new SaslWrapHandler(saslServer::wrap)); + p.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4), + new SaslUnwrapHandler(saslServer::unwrap)); + } + conn.setupDecoder(); + p.remove(this); + p.remove(DECODER_NAME); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + LOG.error("Error when doing SASL handshade, provider={}", conn.provider, cause); + Throwable sendToClient = HBaseSaslRpcServer.unwrap(cause); + doResponse(ctx, SaslStatus.ERROR, null, sendToClient.getClass().getName(), + sendToClient.getLocalizedMessage()); + rpcServer.metrics.authenticationFailure(); + String clientIP = this.toString(); + // attempting user could be null + RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP, + conn.saslServer != null ? conn.saslServer.getAttemptingUser() : "Unknown"); + NettyFutureUtils.safeClose(ctx); + } +} diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcFrameDecoder.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcFrameDecoder.java index 164934ac247..551d1d3fb40 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcFrameDecoder.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcFrameDecoder.java @@ -38,21 +38,18 @@ import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; * @since 2.0.0 */ @InterfaceAudience.Private -public class NettyRpcFrameDecoder extends ByteToMessageDecoder { +class NettyRpcFrameDecoder extends ByteToMessageDecoder { private static int FRAME_LENGTH_FIELD_LENGTH = 4; private final int maxFrameLength; + final NettyServerRpcConnection connection; + private boolean requestTooBig; private String requestTooBigMessage; - public NettyRpcFrameDecoder(int maxFrameLength) { + public NettyRpcFrameDecoder(int maxFrameLength, NettyServerRpcConnection connection) { this.maxFrameLength = maxFrameLength; - } - - NettyServerRpcConnection connection; - - void setConnection(NettyServerRpcConnection connection) { this.connection = connection; } @@ -75,10 +72,10 @@ public class NettyRpcFrameDecoder extends ByteToMessageDecoder { if (frameLength > maxFrameLength) { requestTooBig = true; - requestTooBigMessage = "RPC data length of " + frameLength + " received from " - + connection.getHostAddress() + " is greater than max allowed " - + connection.rpcServer.maxRequestSize + ". Set \"" + SimpleRpcServer.MAX_REQUEST_SIZE - + "\" on server to override this limit (not recommended)"; + requestTooBigMessage = + "RPC data length of " + frameLength + " received from " + connection.getHostAddress() + + " is greater than max allowed " + connection.rpcServer.maxRequestSize + ". Set \"" + + RpcServer.MAX_REQUEST_SIZE + "\" on server to override this limit (not recommended)"; NettyRpcServer.LOG.warn(requestTooBigMessage); @@ -132,7 +129,7 @@ public class NettyRpcFrameDecoder extends ByteToMessageDecoder { // Make sure the client recognizes the underlying exception // Otherwise, throw a DoNotRetryIOException. if ( - VersionInfoUtil.hasMinimumVersion(connection.connectionHeader.getVersionInfo(), + VersionInfoUtil.hasMinimumVersion(connection.getVersionInfo(), RequestTooBigException.MAJOR_VERSION, RequestTooBigException.MINOR_VERSION) ) { reqTooBig.setResponse(null, null, reqTooBigEx, requestTooBigMessage); diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java index 9032e77bf42..8f12b245030 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java @@ -106,11 +106,9 @@ public class NettyRpcServer extends RpcServer { ChannelPipeline pipeline = ch.pipeline(); FixedLengthFrameDecoder preambleDecoder = new FixedLengthFrameDecoder(6); preambleDecoder.setSingleDecode(true); - pipeline.addLast("preambleDecoder", preambleDecoder); - pipeline.addLast("preambleHandler", createNettyRpcServerPreambleHandler()); - pipeline.addLast("frameDecoder", new NettyRpcFrameDecoder(maxRequestSize)); - pipeline.addLast("decoder", new NettyRpcServerRequestDecoder(allChannels, metrics)); - pipeline.addLast("encoder", new NettyRpcServerResponseEncoder(metrics)); + pipeline.addLast(NettyRpcServerPreambleHandler.DECODER_NAME, preambleDecoder); + pipeline.addLast(createNettyRpcServerPreambleHandler(), + new NettyRpcServerResponseEncoder(metrics)); } }); try { @@ -153,6 +151,7 @@ public class NettyRpcServer extends RpcServer { } } + // will be overriden in tests @InterfaceAudience.Private protected NettyRpcServerPreambleHandler createNettyRpcServerPreambleHandler() { return new NettyRpcServerPreambleHandler(NettyRpcServer.this); diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java index cf2551e1c08..15a95bc9b09 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerPreambleHandler.java @@ -18,6 +18,7 @@ package org.apache.hadoop.hbase.ipc; import java.nio.ByteBuffer; +import org.apache.hadoop.hbase.util.NettyFutureUtils; import org.apache.yetus.audience.InterfaceAudience; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; @@ -25,6 +26,7 @@ import org.apache.hbase.thirdparty.io.netty.channel.Channel; import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline; import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; +import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder; /** * Handle connection preamble. @@ -33,6 +35,8 @@ import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; @InterfaceAudience.Private class NettyRpcServerPreambleHandler extends SimpleChannelInboundHandler { + static final String DECODER_NAME = "preambleDecoder"; + private final NettyRpcServer rpcServer; public NettyRpcServerPreambleHandler(NettyRpcServer rpcServer) { @@ -50,12 +54,29 @@ class NettyRpcServerPreambleHandler extends SimpleChannelInboundHandler return; } ChannelPipeline p = ctx.pipeline(); - ((NettyRpcFrameDecoder) p.get("frameDecoder")).setConnection(conn); - ((NettyRpcServerRequestDecoder) p.get("decoder")).setConnection(conn); + if (conn.useSasl) { + LengthFieldBasedFrameDecoder decoder = + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4); + decoder.setSingleDecode(true); + p.addLast(NettyHBaseSaslRpcServerHandler.DECODER_NAME, decoder); + p.addLast(new NettyHBaseSaslRpcServerHandler(rpcServer, conn)); + } else { + conn.setupDecoder(); + } + // add first and then remove, so the single decode decoder will pass the remaining bytes to the + // handler above. p.remove(this); - p.remove("preambleDecoder"); + p.remove(DECODER_NAME); } + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + NettyRpcServer.LOG.warn("Connection {}; caught unexpected downstream exception.", + ctx.channel().remoteAddress(), cause); + NettyFutureUtils.safeClose(ctx); + } + + // will be overridden in tests protected NettyServerRpcConnection createNettyServerRpcConnection(Channel channel) { return new NettyServerRpcConnection(rpcServer, channel); } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerRequestDecoder.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerRequestDecoder.java index cc8b07702b4..2e489e9ab05 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerRequestDecoder.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServerRequestDecoder.java @@ -17,64 +17,42 @@ */ package org.apache.hadoop.hbase.ipc; +import org.apache.hadoop.hbase.util.NettyFutureUtils; import org.apache.yetus.audience.InterfaceAudience; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext; -import org.apache.hbase.thirdparty.io.netty.channel.ChannelInboundHandlerAdapter; -import org.apache.hbase.thirdparty.io.netty.channel.group.ChannelGroup; +import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler; /** * Decoder for rpc request. * @since 2.0.0 */ @InterfaceAudience.Private -class NettyRpcServerRequestDecoder extends ChannelInboundHandlerAdapter { - - private final ChannelGroup allChannels; +class NettyRpcServerRequestDecoder extends SimpleChannelInboundHandler { private final MetricsHBaseServer metrics; - public NettyRpcServerRequestDecoder(ChannelGroup allChannels, MetricsHBaseServer metrics) { - this.allChannels = allChannels; + private final NettyServerRpcConnection connection; + + public NettyRpcServerRequestDecoder(MetricsHBaseServer metrics, + NettyServerRpcConnection connection) { + super(false); this.metrics = metrics; - } - - private NettyServerRpcConnection connection; - - void setConnection(NettyServerRpcConnection connection) { this.connection = connection; } - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - allChannels.add(ctx.channel()); - NettyRpcServer.LOG.trace("Connection {}; # active connections={}", - ctx.channel().remoteAddress(), (allChannels.size() - 1)); - super.channelActive(ctx); - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - ByteBuf input = (ByteBuf) msg; - // 4 bytes length field - metrics.receivedBytes(input.readableBytes() + 4); - connection.process(input); - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - allChannels.remove(ctx.channel()); - NettyRpcServer.LOG.trace("Disconnection {}; # active connections={}", - ctx.channel().remoteAddress(), (allChannels.size() - 1)); - super.channelInactive(ctx); - } - @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable e) { - allChannels.remove(ctx.channel()); - NettyRpcServer.LOG.trace("Connection {}; caught unexpected downstream exception.", + NettyRpcServer.LOG.warn("Connection {}; caught unexpected downstream exception.", ctx.channel().remoteAddress(), e); - ctx.channel().close(); + NettyFutureUtils.safeClose(ctx); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + // 4 bytes length field + metrics.receivedBytes(msg.readableBytes() + 4); + connection.process(msg); } } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java index 58be1376953..60db16d77e0 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyServerRpcConnection.java @@ -20,12 +20,12 @@ package org.apache.hadoop.hbase.ipc; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; -import java.nio.ByteBuffer; import org.apache.hadoop.hbase.CellScanner; import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup; import org.apache.hadoop.hbase.nio.ByteBuff; import org.apache.hadoop.hbase.nio.SingleByteBuff; import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; +import org.apache.hadoop.hbase.util.NettyFutureUtils; import org.apache.yetus.audience.InterfaceAudience; import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService; @@ -33,7 +33,7 @@ import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescrip import org.apache.hbase.thirdparty.com.google.protobuf.Message; import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf; import org.apache.hbase.thirdparty.io.netty.channel.Channel; -import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil; +import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader; @@ -49,10 +49,16 @@ class NettyServerRpcConnection extends ServerRpcConnection { NettyServerRpcConnection(NettyRpcServer rpcServer, Channel channel) { super(rpcServer); this.channel = channel; + rpcServer.allChannels.add(channel); + NettyRpcServer.LOG.trace("Connection {}; # active connections={}", channel.remoteAddress(), + rpcServer.allChannels.size() - 1); // register close hook to release resources - channel.closeFuture().addListener(f -> { + NettyFutureUtils.addListener(channel.closeFuture(), f -> { disposeSasl(); callCleanupIfNeeded(); + NettyRpcServer.LOG.trace("Disconnection {}; # active connections={}", channel.remoteAddress(), + rpcServer.allChannels.size() - 1); + rpcServer.allChannels.remove(channel); }); InetSocketAddress inetSocketAddress = ((InetSocketAddress) channel.remoteAddress()); this.addr = inetSocketAddress.getAddress(); @@ -64,38 +70,22 @@ class NettyServerRpcConnection extends ServerRpcConnection { this.remotePort = inetSocketAddress.getPort(); } - void process(final ByteBuf buf) throws IOException, InterruptedException { - if (connectionHeaderRead) { - this.callCleanup = () -> ReferenceCountUtil.safeRelease(buf); - process(new SingleByteBuff(buf.nioBuffer())); - } else { - ByteBuffer connectionHeader = ByteBuffer.allocate(buf.readableBytes()); - try { - buf.readBytes(connectionHeader); - } finally { - buf.release(); - } - process(connectionHeader); + void setupDecoder() { + ChannelPipeline p = channel.pipeline(); + p.addLast("frameDecoder", new NettyRpcFrameDecoder(rpcServer.maxRequestSize, this)); + p.addLast("decoder", new NettyRpcServerRequestDecoder(rpcServer.metrics, this)); + } + + void process(ByteBuf buf) throws IOException, InterruptedException { + if (skipInitialSaslHandshake) { + skipInitialSaslHandshake = false; + buf.release(); + return; } - } - - void process(ByteBuffer buf) throws IOException, InterruptedException { - process(new SingleByteBuff(buf)); - } - - void process(ByteBuff buf) throws IOException, InterruptedException { + this.callCleanup = () -> buf.release(); + ByteBuff byteBuff = new SingleByteBuff(buf.nioBuffer()); try { - if (skipInitialSaslHandshake) { - skipInitialSaslHandshake = false; - callCleanupIfNeeded(); - return; - } - - if (useSasl) { - saslReadAndProcess(buf); - } else { - processOneRpc(buf); - } + processOneRpc(byteBuff); } catch (Exception e) { callCleanupIfNeeded(); throw e; diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerCall.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerCall.java index 44a7a74006a..bdd3593cf2d 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerCall.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerCall.java @@ -399,37 +399,6 @@ public abstract class ServerCall implements RpcCa return pbBuf; } - protected BufferChain wrapWithSasl(BufferChain bc) throws IOException { - if (!this.connection.useSasl) { - return bc; - } - // Looks like no way around this; saslserver wants a byte array. I have to make it one. - // THIS IS A BIG UGLY COPY. - byte[] responseBytes = bc.getBytes(); - byte[] token; - // synchronization may be needed since there can be multiple Handler - // threads using saslServer or Crypto AES to wrap responses. - if (connection.useCryptoAesWrap) { - // wrap with Crypto AES - synchronized (connection.cryptoAES) { - token = connection.cryptoAES.wrap(responseBytes, 0, responseBytes.length); - } - } else { - synchronized (connection.saslServer) { - token = connection.saslServer.wrap(responseBytes, 0, responseBytes.length); - } - } - if (RpcServer.LOG.isTraceEnabled()) { - RpcServer.LOG - .trace("Adding saslServer wrapped token of size " + token.length + " as call response."); - } - - ByteBuffer[] responseBufs = new ByteBuffer[2]; - responseBufs[0] = ByteBuffer.wrap(Bytes.toBytes(token.length)); - responseBufs[1] = ByteBuffer.wrap(token); - return new BufferChain(responseBufs); - } - @Override public long disconnectSince() { if (!this.connection.isConnectionOpen()) { @@ -556,20 +525,6 @@ public abstract class ServerCall implements RpcCa @Override public synchronized BufferChain getResponse() { - if (connection.useWrap) { - /* - * wrapping result with SASL as the last step just before sending it out, so every message - * must have the right increasing sequence number - */ - try { - return wrapWithSasl(response); - } catch (IOException e) { - /* it is exactly the same what setResponse() does */ - RpcServer.LOG.warn("Exception while creating response " + e); - return null; - } - } else { - return response; - } + return response; } } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java index e5f01adff8c..efb6630ad9e 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/ServerRpcConnection.java @@ -24,15 +24,12 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.opentelemetry.context.propagation.TextMapGetter; -import java.io.ByteArrayInputStream; import java.io.Closeable; import java.io.DataOutputStream; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.ByteBuffer; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; import java.security.GeneralSecurityException; import java.util.Objects; import java.util.Properties; @@ -47,7 +44,6 @@ import org.apache.hadoop.hbase.io.ByteBufferOutputStream; import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup; import org.apache.hadoop.hbase.nio.ByteBuff; -import org.apache.hadoop.hbase.nio.SingleByteBuff; import org.apache.hadoop.hbase.security.AccessDeniedException; import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; import org.apache.hadoop.hbase.security.SaslStatus; @@ -58,7 +54,7 @@ import org.apache.hadoop.hbase.security.provider.SaslServerAuthenticationProvide import org.apache.hadoop.hbase.security.provider.SimpleSaslServerAuthenticationProvider; import org.apache.hadoop.hbase.trace.TraceUtil; import org.apache.hadoop.hbase.util.Bytes; -import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.hbase.util.Pair; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableUtils; @@ -67,7 +63,6 @@ import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; import org.apache.hadoop.security.authorize.AuthorizationException; import org.apache.hadoop.security.authorize.ProxyUsers; -import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.yetus.audience.InterfaceAudience; import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService; @@ -120,16 +115,9 @@ abstract class ServerRpcConnection implements Closeable { protected BlockingService service; protected SaslServerAuthenticationProvider provider; - protected boolean saslContextEstablished; protected boolean skipInitialSaslHandshake; - private ByteBuffer unwrappedData; - // When is this set? FindBugs wants to know! Says NP - private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); protected boolean useSasl; protected HBaseSaslRpcServer saslServer; - protected CryptoAES cryptoAES; - protected boolean useWrap = false; - protected boolean useCryptoAesWrap = false; // was authentication allowed with a fallback to simple auth protected boolean authenticatedWithFallback; @@ -164,7 +152,7 @@ abstract class ServerRpcConnection implements Closeable { } public VersionInfo getVersionInfo() { - if (connectionHeader.hasVersionInfo()) { + if (connectionHeader != null && connectionHeader.hasVersionInfo()) { return connectionHeader.getVersionInfo(); } return null; @@ -181,18 +169,24 @@ abstract class ServerRpcConnection implements Closeable { /** * Set up cell block codecs n */ - private void setupCellBlockCodecs(final ConnectionHeader header) throws FatalConnectionException { + private void setupCellBlockCodecs() throws FatalConnectionException { // TODO: Plug in other supported decoders. - if (!header.hasCellBlockCodecClass()) return; - String className = header.getCellBlockCodecClass(); - if (className == null || className.length() == 0) return; + if (!connectionHeader.hasCellBlockCodecClass()) { + return; + } + String className = connectionHeader.getCellBlockCodecClass(); + if (className == null || className.length() == 0) { + return; + } try { this.codec = (Codec) Class.forName(className).getDeclaredConstructor().newInstance(); } catch (Exception e) { throw new UnsupportedCellCodecException(className, e); } - if (!header.hasCellBlockCompressorClass()) return; - className = header.getCellBlockCompressorClass(); + if (!connectionHeader.hasCellBlockCompressorClass()) { + return; + } + className = connectionHeader.getCellBlockCompressorClass(); try { this.compressionCodec = (CompressionCodec) Class.forName(className).getDeclaredConstructor().newInstance(); @@ -202,21 +196,29 @@ abstract class ServerRpcConnection implements Closeable { } /** - * Set up cipher for rpc encryption with Apache Commons Crypto n + * Set up cipher for rpc encryption with Apache Commons Crypto. */ - private void setupCryptoCipher(final ConnectionHeader header, - RPCProtos.ConnectionHeaderResponse.Builder chrBuilder) throws FatalConnectionException { + private Pair setupCryptoCipher() + throws FatalConnectionException { // If simple auth, return - if (saslServer == null) return; + if (saslServer == null) { + return null; + } // check if rpc encryption with Crypto AES String qop = saslServer.getNegotiatedQop(); boolean isEncryption = SaslUtil.QualityOfProtection.PRIVACY.getSaslQop().equalsIgnoreCase(qop); boolean isCryptoAesEncryption = isEncryption && this.rpcServer.conf.getBoolean("hbase.rpc.crypto.encryption.aes.enabled", false); - if (!isCryptoAesEncryption) return; - if (!header.hasRpcCryptoCipherTransformation()) return; - String transformation = header.getRpcCryptoCipherTransformation(); - if (transformation == null || transformation.length() == 0) return; + if (!isCryptoAesEncryption) { + return null; + } + if (!connectionHeader.hasRpcCryptoCipherTransformation()) { + return null; + } + String transformation = connectionHeader.getRpcCryptoCipherTransformation(); + if (transformation == null || transformation.length() == 0) { + return null; + } // Negotiates AES based on complete saslServer. // The Crypto metadata need to be encrypted and send to client. Properties properties = new Properties(); @@ -242,6 +244,7 @@ abstract class ServerRpcConnection implements Closeable { byte[] inIv = new byte[len]; byte[] outIv = new byte[len]; + CryptoAES cryptoAES; try { // generate the cipher meta data with SecureRandom CryptoRandom secureRandom = CryptoRandomFactory.getCryptoRandom(properties); @@ -252,19 +255,20 @@ abstract class ServerRpcConnection implements Closeable { // create CryptoAES for server cryptoAES = new CryptoAES(transformation, properties, inKey, outKey, inIv, outIv); - // create SaslCipherMeta and send to client, - // for client, the [inKey, outKey], [inIv, outIv] should be reversed - RPCProtos.CryptoCipherMeta.Builder ccmBuilder = RPCProtos.CryptoCipherMeta.newBuilder(); - ccmBuilder.setTransformation(transformation); - ccmBuilder.setInIv(getByteString(outIv)); - ccmBuilder.setInKey(getByteString(outKey)); - ccmBuilder.setOutIv(getByteString(inIv)); - ccmBuilder.setOutKey(getByteString(inKey)); - chrBuilder.setCryptoCipherMeta(ccmBuilder); - useCryptoAesWrap = true; } catch (GeneralSecurityException | IOException ex) { throw new UnsupportedCryptoException(ex.getMessage(), ex); } + // create SaslCipherMeta and send to client, + // for client, the [inKey, outKey], [inIv, outIv] should be reversed + RPCProtos.CryptoCipherMeta.Builder ccmBuilder = RPCProtos.CryptoCipherMeta.newBuilder(); + ccmBuilder.setTransformation(transformation); + ccmBuilder.setInIv(getByteString(outIv)); + ccmBuilder.setInKey(getByteString(outKey)); + ccmBuilder.setOutIv(getByteString(inIv)); + ccmBuilder.setOutKey(getByteString(inKey)); + RPCProtos.ConnectionHeaderResponse resp = + RPCProtos.ConnectionHeaderResponse.newBuilder().setCryptoCipherMeta(ccmBuilder).build(); + return Pair.newPair(resp, cryptoAES); } private ByteString getByteString(byte[] bytes) { @@ -327,125 +331,20 @@ abstract class ServerRpcConnection implements Closeable { doRespond(() -> bc); } - public void saslReadAndProcess(ByteBuff saslToken) throws IOException, InterruptedException { - if (saslContextEstablished) { - RpcServer.LOG.trace("Read input token of size={} for processing by saslServer.unwrap()", - saslToken.limit()); - if (!useWrap) { - processOneRpc(saslToken); - } else { - byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes(); - byte[] plaintextData; - if (useCryptoAesWrap) { - // unwrap with CryptoAES - plaintextData = cryptoAES.unwrap(b, 0, b.length); - } else { - plaintextData = saslServer.unwrap(b, 0, b.length); - } - // release the request buffer as we have already unwrapped all its content - callCleanupIfNeeded(); - processUnwrappedData(plaintextData); - } - } else { - byte[] replyToken; - try { - if (saslServer == null) { - try { - saslServer = - new HBaseSaslRpcServer(provider, rpcServer.saslProps, rpcServer.secretManager); - } catch (Exception e) { - RpcServer.LOG.error("Error when trying to create instance of HBaseSaslRpcServer " - + "with sasl provider: " + provider, e); - throw e; - } - RpcServer.LOG.debug("Created SASL server with mechanism={}", - provider.getSaslAuthMethod().getAuthMethod()); - } - RpcServer.LOG.debug( - "Read input token of size={} for processing by saslServer." + "evaluateResponse()", - saslToken.limit()); - replyToken = saslServer - .evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes()); - } catch (IOException e) { - RpcServer.LOG.debug("Failed to execute SASL handshake", e); - IOException sendToClient = e; - Throwable cause = e; - while (cause != null) { - if (cause instanceof InvalidToken) { - sendToClient = (InvalidToken) cause; - break; - } - cause = cause.getCause(); - } - doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(), - sendToClient.getLocalizedMessage()); - this.rpcServer.metrics.authenticationFailure(); - String clientIP = this.toString(); - // attempting user could be null - RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP, - saslServer.getAttemptingUser()); - throw e; - } finally { - // release the request buffer as we have already unwrapped all its content - callCleanupIfNeeded(); - } - if (replyToken != null) { - if (RpcServer.LOG.isDebugEnabled()) { - RpcServer.LOG.debug("Will send token of size " + replyToken.length + " from saslServer."); - } - doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null); - } - if (saslServer.isComplete()) { - String qop = saslServer.getNegotiatedQop(); - useWrap = qop != null && !"auth".equalsIgnoreCase(qop); - ugi = - provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager); - RpcServer.LOG.debug( - "SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi, - qop); - this.rpcServer.metrics.authenticationSuccess(); - RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi); - saslContextEstablished = true; - } + HBaseSaslRpcServer getOrCreateSaslServer() throws IOException { + if (saslServer == null) { + saslServer = new HBaseSaslRpcServer(provider, rpcServer.saslProps, rpcServer.secretManager); } + return saslServer; } - private void processUnwrappedData(byte[] inBuf) throws IOException, InterruptedException { - ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf)); - // Read all RPCs contained in the inBuf, even partial ones - while (true) { - int count; - if (unwrappedDataLengthBuffer.remaining() > 0) { - count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer); - if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) { - return; - } - } - - if (unwrappedData == null) { - unwrappedDataLengthBuffer.flip(); - int unwrappedDataLength = unwrappedDataLengthBuffer.getInt(); - - if (unwrappedDataLength == RpcClient.PING_CALL_ID) { - if (RpcServer.LOG.isDebugEnabled()) RpcServer.LOG.debug("Received ping message"); - unwrappedDataLengthBuffer.clear(); - continue; // ping message - } - unwrappedData = ByteBuffer.allocate(unwrappedDataLength); - } - - count = this.rpcServer.channelRead(ch, unwrappedData); - if (count <= 0 || unwrappedData.remaining() > 0) { - return; - } - - if (unwrappedData.remaining() == 0) { - unwrappedDataLengthBuffer.clear(); - unwrappedData.flip(); - processOneRpc(new SingleByteBuff(unwrappedData)); - unwrappedData = null; - } - } + void finishSaslNegotiation() throws IOException { + String qop = saslServer.getNegotiatedQop(); + ugi = provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager); + RpcServer.LOG.debug( + "SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi, qop); + rpcServer.metrics.authenticationSuccess(); + RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi); } public void processOneRpc(ByteBuff buf) throws IOException, InterruptedException { @@ -453,6 +352,7 @@ abstract class ServerRpcConnection implements Closeable { processRequest(buf); } else { processConnectionHeader(buf); + callCleanupIfNeeded(); this.connectionHeaderRead = true; if (rpcServer.needAuthorization() && !authorizeConnection()) { // Throw FatalConnectionException wrapping ACE so client does right thing and closes @@ -486,25 +386,35 @@ abstract class ServerRpcConnection implements Closeable { return true; } + private CodedInputStream createCis(ByteBuff buf) { + // Here we read in the header. We avoid having pb + // do its default 4k allocation for CodedInputStream. We force it to use + // backing array. + CodedInputStream cis; + if (buf.hasArray()) { + cis = UnsafeByteOperations + .unsafeWrap(buf.array(), buf.arrayOffset() + buf.position(), buf.limit()).newCodedInput(); + } else { + cis = UnsafeByteOperations.unsafeWrap(new ByteBuffByteInput(buf, buf.limit()), 0, buf.limit()) + .newCodedInput(); + } + cis.enableAliasing(true); + return cis; + } + // Reads the connection header following version private void processConnectionHeader(ByteBuff buf) throws IOException { - if (buf.hasArray()) { - this.connectionHeader = ConnectionHeader.parseFrom(buf.array()); - } else { - CodedInputStream cis = UnsafeByteOperations - .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput(); - cis.enableAliasing(true); - this.connectionHeader = ConnectionHeader.parseFrom(cis); - } + this.connectionHeader = ConnectionHeader.parseFrom(createCis(buf)); String serviceName = connectionHeader.getServiceName(); - if (serviceName == null) throw new EmptyServiceNameException(); + if (serviceName == null) { + throw new EmptyServiceNameException(); + } this.service = RpcServer.getService(this.rpcServer.services, serviceName); - if (this.service == null) throw new UnknownServiceException(serviceName); - setupCellBlockCodecs(this.connectionHeader); - RPCProtos.ConnectionHeaderResponse.Builder chrBuilder = - RPCProtos.ConnectionHeaderResponse.newBuilder(); - setupCryptoCipher(this.connectionHeader, chrBuilder); - responseConnectionHeader(chrBuilder); + if (this.service == null) { + throw new UnknownServiceException(serviceName); + } + setupCellBlockCodecs(); + sendConnectionHeaderResponseIfNeeded(); UserGroupInformation protocolUser = createUser(connectionHeader); if (!useSasl) { ugi = protocolUser; @@ -553,25 +463,35 @@ abstract class ServerRpcConnection implements Closeable { /** * Send the response for connection header */ - private void responseConnectionHeader(RPCProtos.ConnectionHeaderResponse.Builder chrBuilder) - throws FatalConnectionException { + private void sendConnectionHeaderResponseIfNeeded() throws FatalConnectionException { + Pair pair = setupCryptoCipher(); // Response the connection header if Crypto AES is enabled - if (!chrBuilder.hasCryptoCipherMeta()) return; + if (pair == null) { + return; + } try { - byte[] connectionHeaderResBytes = chrBuilder.build().toByteArray(); - // encrypt the Crypto AES cipher meta data with sasl server, and send to client - byte[] unwrapped = new byte[connectionHeaderResBytes.length + 4]; - Bytes.putBytes(unwrapped, 0, Bytes.toBytes(connectionHeaderResBytes.length), 0, 4); - Bytes.putBytes(unwrapped, 4, connectionHeaderResBytes, 0, connectionHeaderResBytes.length); - byte[] wrapped = saslServer.wrap(unwrapped, 0, unwrapped.length); + int size = pair.getFirst().getSerializedSize(); BufferChain bc; - try (ByteBufferOutputStream response = new ByteBufferOutputStream(wrapped.length + 4); - DataOutputStream out = new DataOutputStream(response)) { - out.writeInt(wrapped.length); - out.write(wrapped); - bc = new BufferChain(response.getByteBuffer()); + try (ByteBufferOutputStream bbOut = new ByteBufferOutputStream(4 + size); + DataOutputStream out = new DataOutputStream(bbOut)) { + out.writeInt(size); + pair.getFirst().writeTo(out); + bc = new BufferChain(bbOut.getByteBuffer()); } - doRespond(() -> bc); + doRespond(new RpcResponse() { + + @Override + public BufferChain getResponse() { + return bc; + } + + @Override + public void done() { + // must switch after sending the connection header response, as the client still uses the + // original SaslClient to unwrap the data we send back + saslServer.switchToCryptoAES(pair.getSecond()); + } + }); } catch (IOException ex) { throw new UnsupportedCryptoException(ex.getMessage(), ex); } @@ -581,7 +501,9 @@ abstract class ServerRpcConnection implements Closeable { /** * n * Has the request header and the request param and optionally encoded data buffer all in this - * one array. nn + * one array. + *

+ * Will be overridden in tests. */ protected void processRequest(ByteBuff buf) throws IOException, InterruptedException { long totalRequestSize = buf.limit(); @@ -589,14 +511,7 @@ abstract class ServerRpcConnection implements Closeable { // Here we read in the header. We avoid having pb // do its default 4k allocation for CodedInputStream. We force it to use // backing array. - CodedInputStream cis; - if (buf.hasArray()) { - cis = UnsafeByteOperations.unsafeWrap(buf.array(), 0, buf.limit()).newCodedInput(); - } else { - cis = UnsafeByteOperations - .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput(); - } - cis.enableAliasing(true); + CodedInputStream cis = createCis(buf); int headerSize = cis.readRawVarint32(); offset = cis.getTotalBytesRead(); Message.Builder builder = RequestHeader.newBuilder(); @@ -737,7 +652,7 @@ abstract class ServerRpcConnection implements Closeable { } private void doBadPreambleHandling(String msg, Exception e) throws IOException { - SimpleRpcServer.LOG.warn(msg); + RpcServer.LOG.warn(msg); doRespond(getErrorResponse(msg, e)); } @@ -762,7 +677,7 @@ abstract class ServerRpcConnection implements Closeable { int version = preambleBuffer.get() & 0xFF; byte authbyte = preambleBuffer.get(); - if (version != SimpleRpcServer.CURRENT_VERSION) { + if (version != RpcServer.CURRENT_VERSION) { String msg = getFatalConnectionString(version, authbyte); doBadPreambleHandling(msg, new WrongVersionException(msg)); return false; @@ -810,34 +725,28 @@ abstract class ServerRpcConnection implements Closeable { private static class ByteBuffByteInput extends ByteInput { private ByteBuff buf; - private int offset; private int length; - ByteBuffByteInput(ByteBuff buf, int offset, int length) { + ByteBuffByteInput(ByteBuff buf, int length) { this.buf = buf; - this.offset = offset; this.length = length; } @Override public byte read(int offset) { - return this.buf.get(getAbsoluteOffset(offset)); - } - - private int getAbsoluteOffset(int offset) { - return this.offset + offset; + return this.buf.get(offset); } @Override public int read(int offset, byte[] out, int outOffset, int len) { - this.buf.get(getAbsoluteOffset(offset), out, outOffset, len); + this.buf.get(offset, out, outOffset, len); return len; } @Override public int read(int offset, ByteBuffer out) { int len = out.remaining(); - this.buf.get(out, getAbsoluteOffset(offset), len); + this.buf.get(out, offset, len); return len; } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServerResponder.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServerResponder.java index b9d8d3dffc4..db1b380361d 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServerResponder.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServerResponder.java @@ -18,6 +18,7 @@ package org.apache.hadoop.hbase.ipc; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedChannelException; import java.nio.channels.SelectionKey; @@ -28,6 +29,8 @@ import java.util.Iterator; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.apache.hadoop.hbase.HBaseIOException; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; +import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; import org.apache.hadoop.hbase.util.Threads; import org.apache.hadoop.util.StringUtils; @@ -217,6 +220,28 @@ class SimpleRpcServerResponder extends Thread { } } + private BufferChain wrapWithSasl(HBaseSaslRpcServer saslServer, BufferChain bc) + throws IOException { + // Looks like no way around this; saslserver wants a byte array. I have to make it one. + // THIS IS A BIG UGLY COPY. + byte[] responseBytes = bc.getBytes(); + byte[] token; + // synchronization may be needed since there can be multiple Handler + // threads using saslServer or Crypto AES to wrap responses. + synchronized (saslServer) { + token = saslServer.wrap(responseBytes, 0, responseBytes.length); + } + if (SimpleRpcServer.LOG.isTraceEnabled()) { + SimpleRpcServer.LOG + .trace("Adding saslServer wrapped token of size " + token.length + " as call response."); + } + + ByteBuffer[] responseBufs = new ByteBuffer[2]; + responseBufs[0] = ByteBuffer.wrap(Bytes.toBytes(token.length)); + responseBufs[1] = ByteBuffer.wrap(token); + return new BufferChain(responseBufs); + } + /** * Process the response for this call. You need to have the lock on * {@link org.apache.hadoop.hbase.ipc.SimpleServerRpcConnection#responseWriteLock} @@ -226,6 +251,9 @@ class SimpleRpcServerResponder extends Thread { throws IOException { boolean error = true; BufferChain buf = resp.getResponse(); + if (conn.useWrap) { + buf = wrapWithSasl(conn.saslServer, buf); + } try { // Send as much data as we can in the non-blocking fashion long numBytes = this.simpleRpcServer.channelWrite(conn.channel, buf); diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java index 51e1bedba57..4c8925d7274 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java @@ -17,11 +17,13 @@ */ package org.apache.hadoop.hbase.ipc; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.net.InetAddress; import java.net.Socket; import java.nio.ByteBuffer; +import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; import java.nio.channels.SocketChannel; import java.util.concurrent.ConcurrentLinkedDeque; @@ -34,7 +36,11 @@ import org.apache.hadoop.hbase.client.VersionInfoUtil; import org.apache.hadoop.hbase.exceptions.RequestTooBigException; import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup; import org.apache.hadoop.hbase.nio.ByteBuff; +import org.apache.hadoop.hbase.nio.SingleByteBuff; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; +import org.apache.hadoop.hbase.security.SaslStatus; import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; +import org.apache.hadoop.io.BytesWritable; import org.apache.yetus.audience.InterfaceAudience; import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService; @@ -63,6 +69,11 @@ class SimpleServerRpcConnection extends ServerRpcConnection { // If initial preamble with version and magic has been read or not. private boolean connectionPreambleRead = false; + private boolean saslContextEstablished; + private ByteBuffer unwrappedData; + // When is this set? FindBugs wants to know! Says NP + private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); + boolean useWrap = false; final ConcurrentLinkedDeque responseQueue = new ConcurrentLinkedDeque<>(); final Lock responseWriteLock = new ReentrantLock(); @@ -142,6 +153,110 @@ class SimpleServerRpcConnection extends ServerRpcConnection { } } + private void processUnwrappedData(byte[] inBuf) throws IOException, InterruptedException { + ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf)); + // Read all RPCs contained in the inBuf, even partial ones + while (true) { + int count; + if (unwrappedDataLengthBuffer.remaining() > 0) { + count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer); + if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) { + return; + } + } + + if (unwrappedData == null) { + unwrappedDataLengthBuffer.flip(); + int unwrappedDataLength = unwrappedDataLengthBuffer.getInt(); + + if (unwrappedDataLength == RpcClient.PING_CALL_ID) { + if (RpcServer.LOG.isDebugEnabled()) RpcServer.LOG.debug("Received ping message"); + unwrappedDataLengthBuffer.clear(); + continue; // ping message + } + unwrappedData = ByteBuffer.allocate(unwrappedDataLength); + } + + count = this.rpcServer.channelRead(ch, unwrappedData); + if (count <= 0 || unwrappedData.remaining() > 0) { + return; + } + + if (unwrappedData.remaining() == 0) { + unwrappedDataLengthBuffer.clear(); + unwrappedData.flip(); + processOneRpc(new SingleByteBuff(unwrappedData)); + unwrappedData = null; + } + } + } + + private void saslReadAndProcess(ByteBuff saslToken) throws IOException, InterruptedException { + if (saslContextEstablished) { + RpcServer.LOG.trace("Read input token of size={} for processing by saslServer.unwrap()", + saslToken.limit()); + if (!useWrap) { + processOneRpc(saslToken); + } else { + byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes(); + byte[] plaintextData = saslServer.unwrap(b, 0, b.length); + // release the request buffer as we have already unwrapped all its content + callCleanupIfNeeded(); + processUnwrappedData(plaintextData); + } + } else { + byte[] replyToken; + try { + try { + getOrCreateSaslServer(); + } catch (Exception e) { + RpcServer.LOG.error("Error when trying to create instance of HBaseSaslRpcServer " + + "with sasl provider: " + provider, e); + throw e; + } + RpcServer.LOG.debug("Created SASL server with mechanism={}", + provider.getSaslAuthMethod().getAuthMethod()); + RpcServer.LOG.debug( + "Read input token of size={} for processing by saslServer." + "evaluateResponse()", + saslToken.limit()); + replyToken = saslServer + .evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes()); + } catch (IOException e) { + RpcServer.LOG.debug("Failed to execute SASL handshake", e); + Throwable sendToClient = HBaseSaslRpcServer.unwrap(e); + doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(), + sendToClient.getLocalizedMessage()); + this.rpcServer.metrics.authenticationFailure(); + String clientIP = this.toString(); + // attempting user could be null + RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP, + saslServer.getAttemptingUser()); + throw e; + } finally { + // release the request buffer as we have already unwrapped all its content + callCleanupIfNeeded(); + } + if (replyToken != null) { + if (RpcServer.LOG.isDebugEnabled()) { + RpcServer.LOG.debug("Will send token of size " + replyToken.length + " from saslServer."); + } + doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null); + } + if (saslServer.isComplete()) { + String qop = saslServer.getNegotiatedQop(); + useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + ugi = + provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager); + RpcServer.LOG.debug( + "SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi, + qop); + this.rpcServer.metrics.authenticationSuccess(); + RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi); + saslContextEstablished = true; + } + } + } + /** * Read off the wire. If there is not enough data to read, update the connection state with what * we have and returns. diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcServer.java index eb9913174e2..6d375e0014a 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcServer.java @@ -24,6 +24,7 @@ import java.util.Map; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; +import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; import org.apache.hadoop.hbase.security.provider.AttemptingUserProvidingSaslServer; import org.apache.hadoop.hbase.security.provider.SaslServerAuthenticationProvider; import org.apache.hadoop.security.token.SecretManager; @@ -40,6 +41,7 @@ public class HBaseSaslRpcServer { private final AttemptingUserProvidingSaslServer serverWithProvider; private final SaslServer saslServer; + private CryptoAES cryptoAES; public HBaseSaslRpcServer(SaslServerAuthenticationProvider provider, Map saslProps, SecretManager secretManager) @@ -61,16 +63,28 @@ public class HBaseSaslRpcServer { SaslUtil.safeDispose(saslServer); } + public void switchToCryptoAES(CryptoAES cryptoAES) { + this.cryptoAES = cryptoAES; + } + public String getAttemptingUser() { return serverWithProvider.getAttemptingUser().map(Object::toString).orElse("Unknown"); } public byte[] wrap(byte[] buf, int off, int len) throws SaslException { - return saslServer.wrap(buf, off, len); + if (cryptoAES != null) { + return cryptoAES.wrap(buf, off, len); + } else { + return saslServer.wrap(buf, off, len); + } } public byte[] unwrap(byte[] buf, int off, int len) throws SaslException { - return saslServer.unwrap(buf, off, len); + if (cryptoAES != null) { + return cryptoAES.unwrap(buf, off, len); + } else { + return saslServer.unwrap(buf, off, len); + } } public String getNegotiatedQop() { @@ -92,4 +106,18 @@ public class HBaseSaslRpcServer { } return tokenIdentifier; } + + /** + * Unwrap InvalidToken exception, otherwise return the one passed in. + */ + public static Throwable unwrap(Throwable e) { + Throwable cause = e; + while (cause != null) { + if (cause instanceof InvalidToken) { + return cause; + } + cause = cause.getCause(); + } + return e; + } }