HBASE-27185 Rewrite NettyRpcServer to decode rpc request with netty handler (#4624)

Signed-off-by: Xin Sun <ddupgs@gmail.com>
Duo Zhang 2022-07-27 09:00:42 +08:00 committed by GitHub
parent ac8b3a795f
commit 0c4263a18b
19 changed files with 511 additions and 470 deletions

NettyRpcConnection.java

@@ -223,9 +223,6 @@ class NettyRpcConnection extends RpcConnection {
public void operationComplete(Future<Boolean> future) throws Exception {
if (future.isSuccess()) {
ChannelPipeline p = ch.pipeline();
p.remove(SaslChallengeDecoder.class);
p.remove(NettyHBaseSaslRpcClientHandler.class);
// check if negotiate with server for connection header is necessary
if (saslHandler.isNeedProcessConnectionHeader()) {
Promise<Boolean> connectionHeaderPromise = ch.eventLoop().newPromise();

CryptoAESUnwrapHandler.java (deleted)

@@ -1,47 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.security;
import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
/**
* Unwrap messages with Crypto AES. Should be placed after a
* io.netty.handler.codec.LengthFieldBasedFrameDecoder
*/
@InterfaceAudience.Private
public class CryptoAESUnwrapHandler extends SimpleChannelInboundHandler<ByteBuf> {
private final CryptoAES cryptoAES;
public CryptoAESUnwrapHandler(CryptoAES cryptoAES) {
this.cryptoAES = cryptoAES;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
byte[] bytes = new byte[msg.readableBytes()];
msg.readBytes(bytes);
ctx.fireChannelRead(Unpooled.wrappedBuffer(cryptoAES.unwrap(bytes, 0, bytes.length)));
}
}

CryptoAESWrapHandler.java (deleted)

@@ -1,48 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.security;
import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
/**
* wrap messages with Crypto AES.
*/
@InterfaceAudience.Private
public class CryptoAESWrapHandler extends MessageToByteEncoder<ByteBuf> {
private final CryptoAES cryptoAES;
public CryptoAESWrapHandler(CryptoAES cryptoAES) {
this.cryptoAES = cryptoAES;
}
@Override
protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
byte[] bytes = new byte[msg.readableBytes()];
msg.readBytes(bytes);
byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length);
out.ensureWritable(4 + wrapperBytes.length);
out.writeInt(wrapperBytes.length);
out.writeBytes(wrapperBytes);
}
}

NettyHBaseRpcConnectionHeaderHandler.java

@@ -25,7 +25,6 @@ import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.hbase.thirdparty.io.netty.util.concurrent.Promise;
import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
@@ -92,10 +91,7 @@ public class NettyHBaseRpcConnectionHeaderHandler extends SimpleChannelInboundHa
* Remove handlers for sasl encryption and add handlers for Crypto AES encryption
*/
private void setupCryptoAESHandler(ChannelPipeline p, CryptoAES cryptoAES) {
p.remove(SaslWrapHandler.class);
p.remove(SaslUnwrapHandler.class);
String lengthDecoder = p.context(LengthFieldBasedFrameDecoder.class).name();
p.addAfter(lengthDecoder, null, new CryptoAESUnwrapHandler(cryptoAES));
p.addAfter(lengthDecoder, null, new CryptoAESWrapHandler(cryptoAES));
p.replace(SaslWrapHandler.class, null, new SaslWrapHandler(cryptoAES::wrap));
p.replace(SaslUnwrapHandler.class, null, new SaslUnwrapHandler(cryptoAES::unwrap));
}
}
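The switch from remove-plus-addAfter to ChannelPipeline#replace above keeps each handler's position stable, so the code no longer needs to look up the LengthFieldBasedFrameDecoder by name. A toy sketch of that property, using plain io.netty types and illustrative handler names rather than anything from this patch:

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.channel.embedded.EmbeddedChannel;
    import io.netty.handler.codec.MessageToByteEncoder;

    public class ReplaceInPlaceDemo {
      // stand-in for SaslWrapHandler: tags outbound payloads with 'A'
      static class WrapA extends MessageToByteEncoder<ByteBuf> {
        @Override
        protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) {
          out.writeByte('A').writeBytes(msg);
        }
      }

      // stand-in for the Crypto AES variant: tags with 'B'
      static class WrapB extends MessageToByteEncoder<ByteBuf> {
        @Override
        protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) {
          out.writeByte('B').writeBytes(msg);
        }
      }

      public static void main(String[] args) {
        EmbeddedChannel ch = new EmbeddedChannel(new WrapA());
        // replace() swaps the handler without disturbing pipeline order
        ch.pipeline().replace(WrapA.class, null, new WrapB());
        ch.writeOutbound(ch.alloc().buffer().writeByte('x'));
        ByteBuf out = ch.readOutbound();
        System.out.println((char) out.readByte()); // prints 'B'
        out.release();
        ch.finish();
      }
    }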

NettyHBaseSaslRpcClient.java

@@ -52,9 +52,9 @@ public class NettyHBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
return;
}
// add wrap and unwrap handlers to pipeline.
p.addFirst(new SaslWrapHandler(saslClient),
p.addFirst(new SaslWrapHandler(saslClient::wrap),
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4),
new SaslUnwrapHandler(saslClient));
new SaslUnwrapHandler(saslClient::unwrap));
}
public String getSaslQOP() {

NettyHBaseSaslRpcClientHandler.java

@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.exceptions.ConnectionClosedException;
import org.apache.hadoop.hbase.ipc.FallbackDisallowedException;
import org.apache.hadoop.hbase.security.provider.SaslClientAuthenticationProvider;
import org.apache.hadoop.hbase.util.NettyFutureUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
@@ -33,6 +34,7 @@ import org.slf4j.LoggerFactory;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
import org.apache.hbase.thirdparty.io.netty.util.concurrent.Promise;
@@ -77,7 +79,7 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
private void writeResponse(ChannelHandlerContext ctx, byte[] response) {
LOG.trace("Sending token size={} from initSASLContext.", response.length);
ctx.writeAndFlush(
NettyFutureUtils.safeWriteAndFlush(ctx,
ctx.alloc().buffer(4 + response.length).writeInt(response.length).writeBytes(response));
}
@@ -90,8 +92,11 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
if (LOG.isTraceEnabled()) {
LOG.trace("SASL negotiation for {} is complete", provider.getSaslAuthMethod().getName());
}
ChannelPipeline p = ctx.pipeline();
saslRpcClient.setupSaslHandler(ctx.pipeline());
p.remove(SaslChallengeDecoder.class);
p.remove(this);
setCryptoAESOption();
saslPromise.setSuccess(true);
@@ -110,6 +115,9 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
// dispose the saslRpcClient when the channel is closed, since saslRpcClient is final, it is
// safe to reference it in lambda expr.
NettyFutureUtils.addListener(ctx.channel().closeFuture(), f -> saslRpcClient.dispose());
try {
byte[] initialResponse = ugi.doAs(new PrivilegedExceptionAction<byte[]>() {
@@ -170,14 +178,12 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
saslRpcClient.dispose();
saslPromise.tryFailure(new ConnectionClosedException("Connection closed"));
ctx.fireChannelInactive();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
saslRpcClient.dispose();
saslPromise.tryFailure(cause);
}
}

SaslUnwrapHandler.java

@@ -17,7 +17,7 @@
*/
package org.apache.hadoop.hbase.security;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
@@ -32,22 +32,20 @@ import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
@InterfaceAudience.Private
public class SaslUnwrapHandler extends SimpleChannelInboundHandler<ByteBuf> {
private final SaslClient saslClient;
public SaslUnwrapHandler(SaslClient saslClient) {
this.saslClient = saslClient;
public interface Unwrapper {
byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException;
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
SaslUtil.safeDispose(saslClient);
ctx.fireChannelInactive();
private final Unwrapper unwrapper;
public SaslUnwrapHandler(Unwrapper unwrapper) {
this.unwrapper = unwrapper;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
byte[] bytes = new byte[msg.readableBytes()];
msg.readBytes(bytes);
ctx.fireChannelRead(Unpooled.wrappedBuffer(saslClient.unwrap(bytes, 0, bytes.length)));
ctx.fireChannelRead(Unpooled.wrappedBuffer(unwrapper.unwrap(bytes, 0, bytes.length)));
}
}

SaslWrapHandler.java

@@ -17,7 +17,7 @@
*/
package org.apache.hadoop.hbase.security;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
@@ -30,17 +30,21 @@ import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
@InterfaceAudience.Private
public class SaslWrapHandler extends MessageToByteEncoder<ByteBuf> {
private final SaslClient saslClient;
public interface Wrapper {
byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException;
}
public SaslWrapHandler(SaslClient saslClient) {
this.saslClient = saslClient;
private final Wrapper wrapper;
public SaslWrapHandler(Wrapper wrapper) {
this.wrapper = wrapper;
}
@Override
protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
byte[] bytes = new byte[msg.readableBytes()];
msg.readBytes(bytes);
byte[] wrapperBytes = saslClient.wrap(bytes, 0, bytes.length);
byte[] wrapperBytes = wrapper.wrap(bytes, 0, bytes.length);
out.ensureWritable(4 + wrapperBytes.length);
out.writeInt(wrapperBytes.length);
out.writeBytes(wrapperBytes);
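Because SaslWrapHandler and SaslUnwrapHandler now accept these small functional interfaces instead of a concrete SaslClient, one pair of handler classes serves both plain SASL and Crypto AES traffic. A wiring sketch under the signatures shown above (the class and method names below are illustrative helpers, not part of the patch):

    import javax.security.sasl.SaslClient;
    import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
    import org.apache.hadoop.hbase.security.SaslUnwrapHandler;
    import org.apache.hadoop.hbase.security.SaslWrapHandler;

    final class WrapperWiringSketch {
      // SaslClient#wrap and CryptoAES#wrap share the shape
      // byte[] wrap(byte[], int, int) throws SaslException, so either
      // satisfies SaslWrapHandler.Wrapper as a method reference.
      static SaslWrapHandler wrapWithSasl(SaslClient saslClient) {
        return new SaslWrapHandler(saslClient::wrap);
      }

      static SaslWrapHandler wrapWithCryptoAes(CryptoAES cryptoAES) {
        return new SaslWrapHandler(cryptoAES::wrap);
      }

      static SaslUnwrapHandler unwrapWithSasl(SaslClient saslClient) {
        return new SaslUnwrapHandler(saslClient::unwrap);
      }

      static SaslUnwrapHandler unwrapWithCryptoAes(CryptoAES cryptoAES) {
        return new SaslUnwrapHandler(cryptoAES::unwrap);
      }
    }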

NettyHBaseSaslRpcServerHandler.java (new file)

@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.ipc;
import java.io.IOException;
import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
import org.apache.hadoop.hbase.security.SaslStatus;
import org.apache.hadoop.hbase.security.SaslUnwrapHandler;
import org.apache.hadoop.hbase.security.SaslWrapHandler;
import org.apache.hadoop.hbase.util.NettyFutureUtils;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBufOutputStream;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
/**
* Implement SASL negotiation logic for rpc server.
*/
class NettyHBaseSaslRpcServerHandler extends SimpleChannelInboundHandler<ByteBuf> {
private static final Logger LOG = LoggerFactory.getLogger(NettyHBaseSaslRpcServerHandler.class);
static final String DECODER_NAME = "SaslNegotiationDecoder";
private final NettyRpcServer rpcServer;
private final NettyServerRpcConnection conn;
NettyHBaseSaslRpcServerHandler(NettyRpcServer rpcServer, NettyServerRpcConnection conn) {
this.rpcServer = rpcServer;
this.conn = conn;
}
private void doResponse(ChannelHandlerContext ctx, SaslStatus status, Writable rv,
String errorClass, String error) throws IOException {
// In my testing, have noticed that sasl messages are usually
// in the ballpark of 100-200. That's why the initial capacity is 256.
ByteBuf resp = ctx.alloc().buffer(256);
try (ByteBufOutputStream out = new ByteBufOutputStream(resp)) {
out.writeInt(status.state); // write status
if (status == SaslStatus.SUCCESS) {
rv.write(out);
} else {
WritableUtils.writeString(out, errorClass);
WritableUtils.writeString(out, error);
}
}
NettyFutureUtils.safeWriteAndFlush(ctx, resp);
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
LOG.debug("Read input token of size={} for processing by saslServer.evaluateResponse()",
msg.readableBytes());
HBaseSaslRpcServer saslServer = conn.getOrCreateSaslServer();
byte[] saslToken = new byte[msg.readableBytes()];
msg.readBytes(saslToken, 0, saslToken.length);
byte[] replyToken = saslServer.evaluateResponse(saslToken);
if (replyToken != null) {
LOG.debug("Will send token of size {} from saslServer.", replyToken.length);
doResponse(ctx, SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null);
}
if (saslServer.isComplete()) {
conn.finishSaslNegotiation();
String qop = saslServer.getNegotiatedQop();
boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
ChannelPipeline p = ctx.pipeline();
if (useWrap) {
p.addFirst(new SaslWrapHandler(saslServer::wrap));
p.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4),
new SaslUnwrapHandler(saslServer::unwrap));
}
conn.setupDecoder();
p.remove(this);
p.remove(DECODER_NAME);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
LOG.error("Error when doing SASL handshade, provider={}", conn.provider, cause);
Throwable sendToClient = HBaseSaslRpcServer.unwrap(cause);
doResponse(ctx, SaslStatus.ERROR, null, sendToClient.getClass().getName(),
sendToClient.getLocalizedMessage());
rpcServer.metrics.authenticationFailure();
String clientIP = this.toString();
// attempting user could be null
RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
conn.saslServer != null ? conn.saslServer.getAttemptingUser() : "Unknown");
NettyFutureUtils.safeClose(ctx);
}
}
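The (Integer.MAX_VALUE, 0, 4, 0, 4) LengthFieldBasedFrameDecoder configuration used here and on the client side turns a stream of [4-byte length][payload] records into payload-only buffers. A minimal, runnable illustration with plain io.netty (the patch itself uses the hbase-thirdparty shaded netty):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import io.netty.channel.embedded.EmbeddedChannel;
    import io.netty.handler.codec.LengthFieldBasedFrameDecoder;

    public class SaslFrameDemo {
      public static void main(String[] args) {
        // maxFrame=MAX_VALUE, lengthFieldOffset=0, lengthFieldLength=4,
        // lengthAdjustment=0, initialBytesToStrip=4
        EmbeddedChannel ch = new EmbeddedChannel(
          new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
        byte[] token = "sasl-token".getBytes();
        ByteBuf in = Unpooled.buffer();
        in.writeInt(token.length); // length prefix, stripped by the decoder
        in.writeBytes(token);
        ch.writeInbound(in);
        ByteBuf frame = ch.readInbound();
        System.out.println(frame.readableBytes()); // 10: the prefix is gone
        frame.release();
        ch.finish();
      }
    }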

NettyRpcFrameDecoder.java

@@ -38,21 +38,18 @@ import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
* @since 2.0.0
*/
@InterfaceAudience.Private
public class NettyRpcFrameDecoder extends ByteToMessageDecoder {
class NettyRpcFrameDecoder extends ByteToMessageDecoder {
private static int FRAME_LENGTH_FIELD_LENGTH = 4;
private final int maxFrameLength;
final NettyServerRpcConnection connection;
private boolean requestTooBig;
private String requestTooBigMessage;
public NettyRpcFrameDecoder(int maxFrameLength) {
public NettyRpcFrameDecoder(int maxFrameLength, NettyServerRpcConnection connection) {
this.maxFrameLength = maxFrameLength;
}
NettyServerRpcConnection connection;
void setConnection(NettyServerRpcConnection connection) {
this.connection = connection;
}
@@ -75,10 +72,10 @@ public class NettyRpcFrameDecoder extends ByteToMessageDecoder {
if (frameLength > maxFrameLength) {
requestTooBig = true;
requestTooBigMessage = "RPC data length of " + frameLength + " received from "
+ connection.getHostAddress() + " is greater than max allowed "
+ connection.rpcServer.maxRequestSize + ". Set \"" + SimpleRpcServer.MAX_REQUEST_SIZE
+ "\" on server to override this limit (not recommended)";
requestTooBigMessage =
"RPC data length of " + frameLength + " received from " + connection.getHostAddress()
+ " is greater than max allowed " + connection.rpcServer.maxRequestSize + ". Set \""
+ RpcServer.MAX_REQUEST_SIZE + "\" on server to override this limit (not recommended)";
NettyRpcServer.LOG.warn(requestTooBigMessage);
@@ -132,7 +129,7 @@ public class NettyRpcFrameDecoder extends ByteToMessageDecoder {
// Make sure the client recognizes the underlying exception
// Otherwise, throw a DoNotRetryIOException.
if (
VersionInfoUtil.hasMinimumVersion(connection.connectionHeader.getVersionInfo(),
VersionInfoUtil.hasMinimumVersion(connection.getVersionInfo(),
RequestTooBigException.MAJOR_VERSION, RequestTooBigException.MINOR_VERSION)
) {
reqTooBig.setResponse(null, null, reqTooBigEx, requestTooBigMessage);

NettyRpcServer.java

@@ -106,11 +106,9 @@ public class NettyRpcServer extends RpcServer {
ChannelPipeline pipeline = ch.pipeline();
FixedLengthFrameDecoder preambleDecoder = new FixedLengthFrameDecoder(6);
preambleDecoder.setSingleDecode(true);
pipeline.addLast("preambleDecoder", preambleDecoder);
pipeline.addLast("preambleHandler", createNettyRpcServerPreambleHandler());
pipeline.addLast("frameDecoder", new NettyRpcFrameDecoder(maxRequestSize));
pipeline.addLast("decoder", new NettyRpcServerRequestDecoder(allChannels, metrics));
pipeline.addLast("encoder", new NettyRpcServerResponseEncoder(metrics));
pipeline.addLast(NettyRpcServerPreambleHandler.DECODER_NAME, preambleDecoder);
pipeline.addLast(createNettyRpcServerPreambleHandler(),
new NettyRpcServerResponseEncoder(metrics));
}
});
try {
@@ -153,6 +151,7 @@ public class NettyRpcServer extends RpcServer {
}
}
// will be overridden in tests
@InterfaceAudience.Private
protected NettyRpcServerPreambleHandler createNettyRpcServerPreambleHandler() {
return new NettyRpcServerPreambleHandler(NettyRpcServer.this);

NettyRpcServerPreambleHandler.java

@@ -18,6 +18,7 @@
package org.apache.hadoop.hbase.ipc;
import java.nio.ByteBuffer;
import org.apache.hadoop.hbase.util.NettyFutureUtils;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
@@ -25,6 +26,7 @@ import org.apache.hbase.thirdparty.io.netty.channel.Channel;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
/**
* Handle connection preamble.
@@ -33,6 +35,8 @@ import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
@InterfaceAudience.Private
class NettyRpcServerPreambleHandler extends SimpleChannelInboundHandler<ByteBuf> {
static final String DECODER_NAME = "preambleDecoder";
private final NettyRpcServer rpcServer;
public NettyRpcServerPreambleHandler(NettyRpcServer rpcServer) {
@@ -50,12 +54,29 @@ class NettyRpcServerPreambleHandler extends SimpleChannelInboundHandler<ByteBuf>
return;
}
ChannelPipeline p = ctx.pipeline();
((NettyRpcFrameDecoder) p.get("frameDecoder")).setConnection(conn);
((NettyRpcServerRequestDecoder) p.get("decoder")).setConnection(conn);
if (conn.useSasl) {
LengthFieldBasedFrameDecoder decoder =
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4);
decoder.setSingleDecode(true);
p.addLast(NettyHBaseSaslRpcServerHandler.DECODER_NAME, decoder);
p.addLast(new NettyHBaseSaslRpcServerHandler(rpcServer, conn));
} else {
conn.setupDecoder();
}
// add first and then remove, so the single decode decoder will pass the remaining bytes to the
// handler above.
p.remove(this);
p.remove("preambleDecoder");
p.remove(DECODER_NAME);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
NettyRpcServer.LOG.warn("Connection {}; caught unexpected downstream exception.",
ctx.channel().remoteAddress(), cause);
NettyFutureUtils.safeClose(ctx);
}
// will be overridden in tests
protected NettyServerRpcConnection createNettyServerRpcConnection(Channel channel) {
return new NettyServerRpcConnection(rpcServer, channel);
}
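The "add first and then remove" comment above is the heart of the handoff: the preamble and SASL negotiation decoders run with setSingleDecode(true), so any bytes the client pipelined behind the decoded frame are still buffered inside the decoder, and removing it flushes them to whatever handlers were just installed. A toy demonstration of that ByteToMessageDecoder behavior, again with plain io.netty and made-up payloads:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import io.netty.channel.embedded.EmbeddedChannel;
    import io.netty.handler.codec.FixedLengthFrameDecoder;

    public class SingleDecodeHandoffDemo {
      public static void main(String[] args) {
        FixedLengthFrameDecoder preamble = new FixedLengthFrameDecoder(6);
        preamble.setSingleDecode(true); // emit at most one frame per read
        EmbeddedChannel ch = new EmbeddedChannel(preamble);

        // a 6-byte "preamble" plus 4 pipelined bytes arriving together
        ch.writeInbound(Unpooled.wrappedBuffer("HBasXXrest".getBytes()));
        ByteBuf first = ch.readInbound();
        System.out.println(first.readableBytes()); // 6
        first.release();

        // removing the decoder fires its buffered leftovers ("rest") down
        // the pipeline -- hence "add first and then remove" in the patch
        ch.pipeline().remove(FixedLengthFrameDecoder.class);
        ByteBuf leftover = ch.readInbound();
        System.out.println(leftover.readableBytes()); // 4
        leftover.release();
        ch.finish();
      }
    }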

NettyRpcServerRequestDecoder.java

@@ -17,64 +17,42 @@
*/
package org.apache.hadoop.hbase.ipc;
import org.apache.hadoop.hbase.util.NettyFutureUtils;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.hbase.thirdparty.io.netty.channel.group.ChannelGroup;
import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
/**
* Decoder for rpc request.
* @since 2.0.0
*/
@InterfaceAudience.Private
class NettyRpcServerRequestDecoder extends ChannelInboundHandlerAdapter {
private final ChannelGroup allChannels;
class NettyRpcServerRequestDecoder extends SimpleChannelInboundHandler<ByteBuf> {
private final MetricsHBaseServer metrics;
public NettyRpcServerRequestDecoder(ChannelGroup allChannels, MetricsHBaseServer metrics) {
this.allChannels = allChannels;
private final NettyServerRpcConnection connection;
public NettyRpcServerRequestDecoder(MetricsHBaseServer metrics,
NettyServerRpcConnection connection) {
super(false);
this.metrics = metrics;
}
private NettyServerRpcConnection connection;
void setConnection(NettyServerRpcConnection connection) {
this.connection = connection;
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
allChannels.add(ctx.channel());
NettyRpcServer.LOG.trace("Connection {}; # active connections={}",
ctx.channel().remoteAddress(), (allChannels.size() - 1));
super.channelActive(ctx);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ByteBuf input = (ByteBuf) msg;
// 4 bytes length field
metrics.receivedBytes(input.readableBytes() + 4);
connection.process(input);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
allChannels.remove(ctx.channel());
NettyRpcServer.LOG.trace("Disconnection {}; # active connections={}",
ctx.channel().remoteAddress(), (allChannels.size() - 1));
super.channelInactive(ctx);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable e) {
allChannels.remove(ctx.channel());
NettyRpcServer.LOG.trace("Connection {}; caught unexpected downstream exception.",
NettyRpcServer.LOG.warn("Connection {}; caught unexpected downstream exception.",
ctx.channel().remoteAddress(), e);
ctx.channel().close();
NettyFutureUtils.safeClose(ctx);
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
// 4 bytes length field
metrics.receivedBytes(msg.readableBytes() + 4);
connection.process(msg);
}
}
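A note on the super(false) above: SimpleChannelInboundHandler releases the message after channelRead0 by default, and passing false disables that because connection.process(msg) takes over the buffer (its release is deferred to the call-cleanup hook). A minimal sketch of the ownership rule, under that reading of the patch:

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.channel.SimpleChannelInboundHandler;

    class TakeOwnershipHandler extends SimpleChannelInboundHandler<ByteBuf> {
      TakeOwnershipHandler() {
        super(false); // no auto-release: the handler body now owns msg
      }

      @Override
      protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
        consume(msg); // whoever receives msg must eventually release it
      }

      private void consume(ByteBuf msg) {
        try {
          // ... read from msg ...
        } finally {
          msg.release(); // illustrative; the patch defers this via callCleanup
        }
      }
    }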

NettyServerRpcConnection.java

@@ -20,12 +20,12 @@ package org.apache.hadoop.hbase.ipc;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import org.apache.hadoop.hbase.CellScanner;
import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup;
import org.apache.hadoop.hbase.nio.ByteBuff;
import org.apache.hadoop.hbase.nio.SingleByteBuff;
import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
import org.apache.hadoop.hbase.util.NettyFutureUtils;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService;
@@ -33,7 +33,7 @@ import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescrip
import org.apache.hbase.thirdparty.com.google.protobuf.Message;
import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.hbase.thirdparty.io.netty.channel.Channel;
import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader;
@@ -49,10 +49,16 @@ class NettyServerRpcConnection extends ServerRpcConnection {
NettyServerRpcConnection(NettyRpcServer rpcServer, Channel channel) {
super(rpcServer);
this.channel = channel;
rpcServer.allChannels.add(channel);
NettyRpcServer.LOG.trace("Connection {}; # active connections={}", channel.remoteAddress(),
rpcServer.allChannels.size() - 1);
// register close hook to release resources
channel.closeFuture().addListener(f -> {
NettyFutureUtils.addListener(channel.closeFuture(), f -> {
disposeSasl();
callCleanupIfNeeded();
NettyRpcServer.LOG.trace("Disconnection {}; # active connections={}", channel.remoteAddress(),
rpcServer.allChannels.size() - 1);
rpcServer.allChannels.remove(channel);
});
InetSocketAddress inetSocketAddress = ((InetSocketAddress) channel.remoteAddress());
this.addr = inetSocketAddress.getAddress();
@@ -64,38 +70,22 @@ class NettyServerRpcConnection extends ServerRpcConnection {
this.remotePort = inetSocketAddress.getPort();
}
void process(final ByteBuf buf) throws IOException, InterruptedException {
if (connectionHeaderRead) {
this.callCleanup = () -> ReferenceCountUtil.safeRelease(buf);
process(new SingleByteBuff(buf.nioBuffer()));
} else {
ByteBuffer connectionHeader = ByteBuffer.allocate(buf.readableBytes());
try {
buf.readBytes(connectionHeader);
} finally {
buf.release();
}
process(connectionHeader);
}
void setupDecoder() {
ChannelPipeline p = channel.pipeline();
p.addLast("frameDecoder", new NettyRpcFrameDecoder(rpcServer.maxRequestSize, this));
p.addLast("decoder", new NettyRpcServerRequestDecoder(rpcServer.metrics, this));
}
void process(ByteBuffer buf) throws IOException, InterruptedException {
process(new SingleByteBuff(buf));
}
void process(ByteBuff buf) throws IOException, InterruptedException {
try {
void process(ByteBuf buf) throws IOException, InterruptedException {
if (skipInitialSaslHandshake) {
skipInitialSaslHandshake = false;
callCleanupIfNeeded();
buf.release();
return;
}
if (useSasl) {
saslReadAndProcess(buf);
} else {
processOneRpc(buf);
}
this.callCleanup = () -> buf.release();
ByteBuff byteBuff = new SingleByteBuff(buf.nioBuffer());
try {
processOneRpc(byteBuff);
} catch (Exception e) {
callCleanupIfNeeded();
throw e;

ServerCall.java

@@ -399,37 +399,6 @@ public abstract class ServerCall<T extends ServerRpcConnection> implements RpcCa
return pbBuf;
}
protected BufferChain wrapWithSasl(BufferChain bc) throws IOException {
if (!this.connection.useSasl) {
return bc;
}
// Looks like no way around this; saslserver wants a byte array. I have to make it one.
// THIS IS A BIG UGLY COPY.
byte[] responseBytes = bc.getBytes();
byte[] token;
// synchronization may be needed since there can be multiple Handler
// threads using saslServer or Crypto AES to wrap responses.
if (connection.useCryptoAesWrap) {
// wrap with Crypto AES
synchronized (connection.cryptoAES) {
token = connection.cryptoAES.wrap(responseBytes, 0, responseBytes.length);
}
} else {
synchronized (connection.saslServer) {
token = connection.saslServer.wrap(responseBytes, 0, responseBytes.length);
}
}
if (RpcServer.LOG.isTraceEnabled()) {
RpcServer.LOG
.trace("Adding saslServer wrapped token of size " + token.length + " as call response.");
}
ByteBuffer[] responseBufs = new ByteBuffer[2];
responseBufs[0] = ByteBuffer.wrap(Bytes.toBytes(token.length));
responseBufs[1] = ByteBuffer.wrap(token);
return new BufferChain(responseBufs);
}
@Override
public long disconnectSince() {
if (!this.connection.isConnectionOpen()) {
@@ -556,20 +525,6 @@ public abstract class ServerCall<T extends ServerRpcConnection> implements RpcCa
@Override
public synchronized BufferChain getResponse() {
if (connection.useWrap) {
/*
* wrapping result with SASL as the last step just before sending it out, so every message
* must have the right increasing sequence number
*/
try {
return wrapWithSasl(response);
} catch (IOException e) {
/* it is exactly the same what setResponse() does */
RpcServer.LOG.warn("Exception while creating response " + e);
return null;
}
} else {
return response;
}
}
}

ServerRpcConnection.java

@@ -24,15 +24,12 @@ import io.opentelemetry.api.trace.Span;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.context.propagation.TextMapGetter;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.security.GeneralSecurityException;
import java.util.Objects;
import java.util.Properties;
@@ -47,7 +44,6 @@ import org.apache.hadoop.hbase.io.ByteBufferOutputStream;
import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup;
import org.apache.hadoop.hbase.nio.ByteBuff;
import org.apache.hadoop.hbase.nio.SingleByteBuff;
import org.apache.hadoop.hbase.security.AccessDeniedException;
import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
import org.apache.hadoop.hbase.security.SaslStatus;
@@ -58,7 +54,7 @@ import org.apache.hadoop.hbase.security.provider.SaslServerAuthenticationProvide
import org.apache.hadoop.hbase.security.provider.SimpleSaslServerAuthenticationProvider;
import org.apache.hadoop.hbase.trace.TraceUtil;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.hbase.util.Pair;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
@@ -67,7 +63,6 @@ import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
import org.apache.hadoop.security.authorize.AuthorizationException;
import org.apache.hadoop.security.authorize.ProxyUsers;
import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService;
@@ -120,16 +115,9 @@ abstract class ServerRpcConnection implements Closeable {
protected BlockingService service;
protected SaslServerAuthenticationProvider provider;
protected boolean saslContextEstablished;
protected boolean skipInitialSaslHandshake;
private ByteBuffer unwrappedData;
// When is this set? FindBugs wants to know! Says NP
private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
protected boolean useSasl;
protected HBaseSaslRpcServer saslServer;
protected CryptoAES cryptoAES;
protected boolean useWrap = false;
protected boolean useCryptoAesWrap = false;
// was authentication allowed with a fallback to simple auth
protected boolean authenticatedWithFallback;
@@ -164,7 +152,7 @@ abstract class ServerRpcConnection implements Closeable {
}
public VersionInfo getVersionInfo() {
if (connectionHeader.hasVersionInfo()) {
if (connectionHeader != null && connectionHeader.hasVersionInfo()) {
return connectionHeader.getVersionInfo();
}
return null;
@@ -181,18 +169,24 @@ abstract class ServerRpcConnection implements Closeable {
/**
* Set up cell block codecs.
*/
private void setupCellBlockCodecs(final ConnectionHeader header) throws FatalConnectionException {
private void setupCellBlockCodecs() throws FatalConnectionException {
// TODO: Plug in other supported decoders.
if (!header.hasCellBlockCodecClass()) return;
String className = header.getCellBlockCodecClass();
if (className == null || className.length() == 0) return;
if (!connectionHeader.hasCellBlockCodecClass()) {
return;
}
String className = connectionHeader.getCellBlockCodecClass();
if (className == null || className.length() == 0) {
return;
}
try {
this.codec = (Codec) Class.forName(className).getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new UnsupportedCellCodecException(className, e);
}
if (!header.hasCellBlockCompressorClass()) return;
className = header.getCellBlockCompressorClass();
if (!connectionHeader.hasCellBlockCompressorClass()) {
return;
}
className = connectionHeader.getCellBlockCompressorClass();
try {
this.compressionCodec =
(CompressionCodec) Class.forName(className).getDeclaredConstructor().newInstance();
@@ -202,21 +196,29 @@ abstract class ServerRpcConnection implements Closeable {
}
/**
* Set up cipher for rpc encryption with Apache Commons Crypto n
* Set up cipher for rpc encryption with Apache Commons Crypto.
*/
private void setupCryptoCipher(final ConnectionHeader header,
RPCProtos.ConnectionHeaderResponse.Builder chrBuilder) throws FatalConnectionException {
private Pair<RPCProtos.ConnectionHeaderResponse, CryptoAES> setupCryptoCipher()
throws FatalConnectionException {
// If simple auth, return
if (saslServer == null) return;
if (saslServer == null) {
return null;
}
// check if rpc encryption with Crypto AES
String qop = saslServer.getNegotiatedQop();
boolean isEncryption = SaslUtil.QualityOfProtection.PRIVACY.getSaslQop().equalsIgnoreCase(qop);
boolean isCryptoAesEncryption = isEncryption
&& this.rpcServer.conf.getBoolean("hbase.rpc.crypto.encryption.aes.enabled", false);
if (!isCryptoAesEncryption) return;
if (!header.hasRpcCryptoCipherTransformation()) return;
String transformation = header.getRpcCryptoCipherTransformation();
if (transformation == null || transformation.length() == 0) return;
if (!isCryptoAesEncryption) {
return null;
}
if (!connectionHeader.hasRpcCryptoCipherTransformation()) {
return null;
}
String transformation = connectionHeader.getRpcCryptoCipherTransformation();
if (transformation == null || transformation.length() == 0) {
return null;
}
// Negotiates AES based on complete saslServer.
// The Crypto metadata need to be encrypted and send to client.
Properties properties = new Properties();
@@ -242,6 +244,7 @@ abstract class ServerRpcConnection implements Closeable {
byte[] inIv = new byte[len];
byte[] outIv = new byte[len];
CryptoAES cryptoAES;
try {
// generate the cipher meta data with SecureRandom
CryptoRandom secureRandom = CryptoRandomFactory.getCryptoRandom(properties);
@@ -252,6 +255,9 @@ abstract class ServerRpcConnection implements Closeable {
// create CryptoAES for server
cryptoAES = new CryptoAES(transformation, properties, inKey, outKey, inIv, outIv);
} catch (GeneralSecurityException | IOException ex) {
throw new UnsupportedCryptoException(ex.getMessage(), ex);
}
// create SaslCipherMeta and send to client,
// for client, the [inKey, outKey], [inIv, outIv] should be reversed
RPCProtos.CryptoCipherMeta.Builder ccmBuilder = RPCProtos.CryptoCipherMeta.newBuilder();
@@ -260,11 +266,9 @@ abstract class ServerRpcConnection implements Closeable {
ccmBuilder.setInKey(getByteString(outKey));
ccmBuilder.setOutIv(getByteString(inIv));
ccmBuilder.setOutKey(getByteString(inKey));
chrBuilder.setCryptoCipherMeta(ccmBuilder);
useCryptoAesWrap = true;
} catch (GeneralSecurityException | IOException ex) {
throw new UnsupportedCryptoException(ex.getMessage(), ex);
}
RPCProtos.ConnectionHeaderResponse resp =
RPCProtos.ConnectionHeaderResponse.newBuilder().setCryptoCipherMeta(ccmBuilder).build();
return Pair.newPair(resp, cryptoAES);
}
private ByteString getByteString(byte[] bytes) {
@@ -327,125 +331,20 @@ abstract class ServerRpcConnection implements Closeable {
doRespond(() -> bc);
}
public void saslReadAndProcess(ByteBuff saslToken) throws IOException, InterruptedException {
if (saslContextEstablished) {
RpcServer.LOG.trace("Read input token of size={} for processing by saslServer.unwrap()",
saslToken.limit());
if (!useWrap) {
processOneRpc(saslToken);
} else {
byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes();
byte[] plaintextData;
if (useCryptoAesWrap) {
// unwrap with CryptoAES
plaintextData = cryptoAES.unwrap(b, 0, b.length);
} else {
plaintextData = saslServer.unwrap(b, 0, b.length);
}
// release the request buffer as we have already unwrapped all its content
callCleanupIfNeeded();
processUnwrappedData(plaintextData);
}
} else {
byte[] replyToken;
try {
HBaseSaslRpcServer getOrCreateSaslServer() throws IOException {
if (saslServer == null) {
try {
saslServer =
new HBaseSaslRpcServer(provider, rpcServer.saslProps, rpcServer.secretManager);
} catch (Exception e) {
RpcServer.LOG.error("Error when trying to create instance of HBaseSaslRpcServer "
+ "with sasl provider: " + provider, e);
throw e;
saslServer = new HBaseSaslRpcServer(provider, rpcServer.saslProps, rpcServer.secretManager);
}
RpcServer.LOG.debug("Created SASL server with mechanism={}",
provider.getSaslAuthMethod().getAuthMethod());
return saslServer;
}
RpcServer.LOG.debug(
"Read input token of size={} for processing by saslServer." + "evaluateResponse()",
saslToken.limit());
replyToken = saslServer
.evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes());
} catch (IOException e) {
RpcServer.LOG.debug("Failed to execute SASL handshake", e);
IOException sendToClient = e;
Throwable cause = e;
while (cause != null) {
if (cause instanceof InvalidToken) {
sendToClient = (InvalidToken) cause;
break;
}
cause = cause.getCause();
}
doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(),
sendToClient.getLocalizedMessage());
this.rpcServer.metrics.authenticationFailure();
String clientIP = this.toString();
// attempting user could be null
RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
saslServer.getAttemptingUser());
throw e;
} finally {
// release the request buffer as we have already unwrapped all its content
callCleanupIfNeeded();
}
if (replyToken != null) {
if (RpcServer.LOG.isDebugEnabled()) {
RpcServer.LOG.debug("Will send token of size " + replyToken.length + " from saslServer.");
}
doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null);
}
if (saslServer.isComplete()) {
void finishSaslNegotiation() throws IOException {
String qop = saslServer.getNegotiatedQop();
useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
ugi =
provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager);
ugi = provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager);
RpcServer.LOG.debug(
"SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi,
qop);
this.rpcServer.metrics.authenticationSuccess();
"SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi, qop);
rpcServer.metrics.authenticationSuccess();
RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi);
saslContextEstablished = true;
}
}
}
private void processUnwrappedData(byte[] inBuf) throws IOException, InterruptedException {
ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf));
// Read all RPCs contained in the inBuf, even partial ones
while (true) {
int count;
if (unwrappedDataLengthBuffer.remaining() > 0) {
count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer);
if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) {
return;
}
}
if (unwrappedData == null) {
unwrappedDataLengthBuffer.flip();
int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
if (unwrappedDataLength == RpcClient.PING_CALL_ID) {
if (RpcServer.LOG.isDebugEnabled()) RpcServer.LOG.debug("Received ping message");
unwrappedDataLengthBuffer.clear();
continue; // ping message
}
unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
}
count = this.rpcServer.channelRead(ch, unwrappedData);
if (count <= 0 || unwrappedData.remaining() > 0) {
return;
}
if (unwrappedData.remaining() == 0) {
unwrappedDataLengthBuffer.clear();
unwrappedData.flip();
processOneRpc(new SingleByteBuff(unwrappedData));
unwrappedData = null;
}
}
}
public void processOneRpc(ByteBuff buf) throws IOException, InterruptedException {
@@ -453,6 +352,7 @@ abstract class ServerRpcConnection implements Closeable {
processRequest(buf);
} else {
processConnectionHeader(buf);
callCleanupIfNeeded();
this.connectionHeaderRead = true;
if (rpcServer.needAuthorization() && !authorizeConnection()) {
// Throw FatalConnectionException wrapping ACE so client does right thing and closes
@@ -486,25 +386,35 @@ abstract class ServerRpcConnection implements Closeable {
return true;
}
private CodedInputStream createCis(ByteBuff buf) {
// Here we read in the header. We avoid having pb
// do its default 4k allocation for CodedInputStream. We force it to use
// backing array.
CodedInputStream cis;
if (buf.hasArray()) {
cis = UnsafeByteOperations
.unsafeWrap(buf.array(), buf.arrayOffset() + buf.position(), buf.limit()).newCodedInput();
} else {
cis = UnsafeByteOperations.unsafeWrap(new ByteBuffByteInput(buf, buf.limit()), 0, buf.limit())
.newCodedInput();
}
cis.enableAliasing(true);
return cis;
}
// Reads the connection header following version
private void processConnectionHeader(ByteBuff buf) throws IOException {
if (buf.hasArray()) {
this.connectionHeader = ConnectionHeader.parseFrom(buf.array());
} else {
CodedInputStream cis = UnsafeByteOperations
.unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput();
cis.enableAliasing(true);
this.connectionHeader = ConnectionHeader.parseFrom(cis);
}
this.connectionHeader = ConnectionHeader.parseFrom(createCis(buf));
String serviceName = connectionHeader.getServiceName();
if (serviceName == null) throw new EmptyServiceNameException();
if (serviceName == null) {
throw new EmptyServiceNameException();
}
this.service = RpcServer.getService(this.rpcServer.services, serviceName);
if (this.service == null) throw new UnknownServiceException(serviceName);
setupCellBlockCodecs(this.connectionHeader);
RPCProtos.ConnectionHeaderResponse.Builder chrBuilder =
RPCProtos.ConnectionHeaderResponse.newBuilder();
setupCryptoCipher(this.connectionHeader, chrBuilder);
responseConnectionHeader(chrBuilder);
if (this.service == null) {
throw new UnknownServiceException(serviceName);
}
setupCellBlockCodecs();
sendConnectionHeaderResponseIfNeeded();
UserGroupInformation protocolUser = createUser(connectionHeader);
if (!useSasl) {
ugi = protocolUser;
@@ -553,25 +463,35 @@ abstract class ServerRpcConnection implements Closeable {
/**
* Send the response for connection header
*/
private void responseConnectionHeader(RPCProtos.ConnectionHeaderResponse.Builder chrBuilder)
throws FatalConnectionException {
private void sendConnectionHeaderResponseIfNeeded() throws FatalConnectionException {
Pair<RPCProtos.ConnectionHeaderResponse, CryptoAES> pair = setupCryptoCipher();
// Response the connection header if Crypto AES is enabled
if (!chrBuilder.hasCryptoCipherMeta()) return;
try {
byte[] connectionHeaderResBytes = chrBuilder.build().toByteArray();
// encrypt the Crypto AES cipher meta data with sasl server, and send to client
byte[] unwrapped = new byte[connectionHeaderResBytes.length + 4];
Bytes.putBytes(unwrapped, 0, Bytes.toBytes(connectionHeaderResBytes.length), 0, 4);
Bytes.putBytes(unwrapped, 4, connectionHeaderResBytes, 0, connectionHeaderResBytes.length);
byte[] wrapped = saslServer.wrap(unwrapped, 0, unwrapped.length);
BufferChain bc;
try (ByteBufferOutputStream response = new ByteBufferOutputStream(wrapped.length + 4);
DataOutputStream out = new DataOutputStream(response)) {
out.writeInt(wrapped.length);
out.write(wrapped);
bc = new BufferChain(response.getByteBuffer());
if (pair == null) {
return;
}
doRespond(() -> bc);
try {
int size = pair.getFirst().getSerializedSize();
BufferChain bc;
try (ByteBufferOutputStream bbOut = new ByteBufferOutputStream(4 + size);
DataOutputStream out = new DataOutputStream(bbOut)) {
out.writeInt(size);
pair.getFirst().writeTo(out);
bc = new BufferChain(bbOut.getByteBuffer());
}
doRespond(new RpcResponse() {
@Override
public BufferChain getResponse() {
return bc;
}
@Override
public void done() {
// must switch after sending the connection header response, as the client still uses the
// original SaslClient to unwrap the data we send back
saslServer.switchToCryptoAES(pair.getSecond());
}
});
} catch (IOException ex) {
throw new UnsupportedCryptoException(ex.getMessage(), ex);
}
@@ -581,7 +501,9 @@ abstract class ServerRpcConnection implements Closeable {
/**
* n * Has the request header and the request param and optionally encoded data buffer all in this
* one array. nn
* one array.
* <p/>
* Will be overridden in tests.
*/
protected void processRequest(ByteBuff buf) throws IOException, InterruptedException {
long totalRequestSize = buf.limit();
@@ -589,14 +511,7 @@ abstract class ServerRpcConnection implements Closeable {
// Here we read in the header. We avoid having pb
// do its default 4k allocation for CodedInputStream. We force it to use
// backing array.
CodedInputStream cis;
if (buf.hasArray()) {
cis = UnsafeByteOperations.unsafeWrap(buf.array(), 0, buf.limit()).newCodedInput();
} else {
cis = UnsafeByteOperations
.unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput();
}
cis.enableAliasing(true);
CodedInputStream cis = createCis(buf);
int headerSize = cis.readRawVarint32();
offset = cis.getTotalBytesRead();
Message.Builder builder = RequestHeader.newBuilder();
@@ -737,7 +652,7 @@ abstract class ServerRpcConnection implements Closeable {
}
private void doBadPreambleHandling(String msg, Exception e) throws IOException {
SimpleRpcServer.LOG.warn(msg);
RpcServer.LOG.warn(msg);
doRespond(getErrorResponse(msg, e));
}
@@ -762,7 +677,7 @@ abstract class ServerRpcConnection implements Closeable {
int version = preambleBuffer.get() & 0xFF;
byte authbyte = preambleBuffer.get();
if (version != SimpleRpcServer.CURRENT_VERSION) {
if (version != RpcServer.CURRENT_VERSION) {
String msg = getFatalConnectionString(version, authbyte);
doBadPreambleHandling(msg, new WrongVersionException(msg));
return false;
@@ -810,34 +725,28 @@ abstract class ServerRpcConnection implements Closeable {
private static class ByteBuffByteInput extends ByteInput {
private ByteBuff buf;
private int offset;
private int length;
ByteBuffByteInput(ByteBuff buf, int offset, int length) {
ByteBuffByteInput(ByteBuff buf, int length) {
this.buf = buf;
this.offset = offset;
this.length = length;
}
@Override
public byte read(int offset) {
return this.buf.get(getAbsoluteOffset(offset));
}
private int getAbsoluteOffset(int offset) {
return this.offset + offset;
return this.buf.get(offset);
}
@Override
public int read(int offset, byte[] out, int outOffset, int len) {
this.buf.get(getAbsoluteOffset(offset), out, outOffset, len);
this.buf.get(offset, out, outOffset, len);
return len;
}
@Override
public int read(int offset, ByteBuffer out) {
int len = out.remaining();
this.buf.get(out, getAbsoluteOffset(offset), len);
this.buf.get(out, offset, len);
return len;
}

SimpleRpcServerResponder.java

@@ -18,6 +18,7 @@
package org.apache.hadoop.hbase.ipc;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
@@ -28,6 +29,8 @@ import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.hadoop.hbase.HBaseIOException;
import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
import org.apache.hadoop.hbase.util.Threads;
import org.apache.hadoop.util.StringUtils;
@@ -217,6 +220,28 @@ class SimpleRpcServerResponder extends Thread {
}
}
private BufferChain wrapWithSasl(HBaseSaslRpcServer saslServer, BufferChain bc)
throws IOException {
// Looks like no way around this; saslserver wants a byte array. I have to make it one.
// THIS IS A BIG UGLY COPY.
byte[] responseBytes = bc.getBytes();
byte[] token;
// synchronization may be needed since there can be multiple Handler
// threads using saslServer or Crypto AES to wrap responses.
synchronized (saslServer) {
token = saslServer.wrap(responseBytes, 0, responseBytes.length);
}
if (SimpleRpcServer.LOG.isTraceEnabled()) {
SimpleRpcServer.LOG
.trace("Adding saslServer wrapped token of size " + token.length + " as call response.");
}
ByteBuffer[] responseBufs = new ByteBuffer[2];
responseBufs[0] = ByteBuffer.wrap(Bytes.toBytes(token.length));
responseBufs[1] = ByteBuffer.wrap(token);
return new BufferChain(responseBufs);
}
/**
* Process the response for this call. You need to have the lock on
* {@link org.apache.hadoop.hbase.ipc.SimpleServerRpcConnection#responseWriteLock}
@@ -226,6 +251,9 @@ class SimpleRpcServerResponder extends Thread {
throws IOException {
boolean error = true;
BufferChain buf = resp.getResponse();
if (conn.useWrap) {
buf = wrapWithSasl(conn.saslServer, buf);
}
try {
// Send as much data as we can in the non-blocking fashion
long numBytes = this.simpleRpcServer.channelWrite(conn.channel, buf);

SimpleServerRpcConnection.java

@@ -17,11 +17,13 @@
*/
package org.apache.hadoop.hbase.ipc;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SocketChannel;
import java.util.concurrent.ConcurrentLinkedDeque;
@@ -34,7 +36,11 @@ import org.apache.hadoop.hbase.client.VersionInfoUtil;
import org.apache.hadoop.hbase.exceptions.RequestTooBigException;
import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup;
import org.apache.hadoop.hbase.nio.ByteBuff;
import org.apache.hadoop.hbase.nio.SingleByteBuff;
import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
import org.apache.hadoop.hbase.security.SaslStatus;
import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
import org.apache.hadoop.io.BytesWritable;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService;
@@ -63,6 +69,11 @@ class SimpleServerRpcConnection extends ServerRpcConnection {
// If initial preamble with version and magic has been read or not.
private boolean connectionPreambleRead = false;
private boolean saslContextEstablished;
private ByteBuffer unwrappedData;
// When is this set? FindBugs wants to know! Says NP
private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
boolean useWrap = false;
final ConcurrentLinkedDeque<RpcResponse> responseQueue = new ConcurrentLinkedDeque<>();
final Lock responseWriteLock = new ReentrantLock();
@@ -142,6 +153,110 @@ class SimpleServerRpcConnection extends ServerRpcConnection {
}
}
private void processUnwrappedData(byte[] inBuf) throws IOException, InterruptedException {
ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf));
// Read all RPCs contained in the inBuf, even partial ones
while (true) {
int count;
if (unwrappedDataLengthBuffer.remaining() > 0) {
count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer);
if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) {
return;
}
}
if (unwrappedData == null) {
unwrappedDataLengthBuffer.flip();
int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
if (unwrappedDataLength == RpcClient.PING_CALL_ID) {
if (RpcServer.LOG.isDebugEnabled()) RpcServer.LOG.debug("Received ping message");
unwrappedDataLengthBuffer.clear();
continue; // ping message
}
unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
}
count = this.rpcServer.channelRead(ch, unwrappedData);
if (count <= 0 || unwrappedData.remaining() > 0) {
return;
}
if (unwrappedData.remaining() == 0) {
unwrappedDataLengthBuffer.clear();
unwrappedData.flip();
processOneRpc(new SingleByteBuff(unwrappedData));
unwrappedData = null;
}
}
}
private void saslReadAndProcess(ByteBuff saslToken) throws IOException, InterruptedException {
if (saslContextEstablished) {
RpcServer.LOG.trace("Read input token of size={} for processing by saslServer.unwrap()",
saslToken.limit());
if (!useWrap) {
processOneRpc(saslToken);
} else {
byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes();
byte[] plaintextData = saslServer.unwrap(b, 0, b.length);
// release the request buffer as we have already unwrapped all its content
callCleanupIfNeeded();
processUnwrappedData(plaintextData);
}
} else {
byte[] replyToken;
try {
try {
getOrCreateSaslServer();
} catch (Exception e) {
RpcServer.LOG.error("Error when trying to create instance of HBaseSaslRpcServer "
+ "with sasl provider: " + provider, e);
throw e;
}
RpcServer.LOG.debug("Created SASL server with mechanism={}",
provider.getSaslAuthMethod().getAuthMethod());
RpcServer.LOG.debug(
"Read input token of size={} for processing by saslServer." + "evaluateResponse()",
saslToken.limit());
replyToken = saslServer
.evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes());
} catch (IOException e) {
RpcServer.LOG.debug("Failed to execute SASL handshake", e);
Throwable sendToClient = HBaseSaslRpcServer.unwrap(e);
doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(),
sendToClient.getLocalizedMessage());
this.rpcServer.metrics.authenticationFailure();
String clientIP = this.toString();
// attempting user could be null
RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
saslServer.getAttemptingUser());
throw e;
} finally {
// release the request buffer as we have already unwrapped all its content
callCleanupIfNeeded();
}
if (replyToken != null) {
if (RpcServer.LOG.isDebugEnabled()) {
RpcServer.LOG.debug("Will send token of size " + replyToken.length + " from saslServer.");
}
doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null);
}
if (saslServer.isComplete()) {
String qop = saslServer.getNegotiatedQop();
useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
ugi =
provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager);
RpcServer.LOG.debug(
"SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi,
qop);
this.rpcServer.metrics.authenticationSuccess();
RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi);
saslContextEstablished = true;
}
}
}
/**
* Read off the wire. If there is not enough data to read, update the connection state with what
* we have and returns.

HBaseSaslRpcServer.java

@@ -24,6 +24,7 @@ import java.util.Map;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
import org.apache.hadoop.hbase.security.provider.AttemptingUserProvidingSaslServer;
import org.apache.hadoop.hbase.security.provider.SaslServerAuthenticationProvider;
import org.apache.hadoop.security.token.SecretManager;
@@ -40,6 +41,7 @@ public class HBaseSaslRpcServer {
private final AttemptingUserProvidingSaslServer serverWithProvider;
private final SaslServer saslServer;
private CryptoAES cryptoAES;
public HBaseSaslRpcServer(SaslServerAuthenticationProvider provider,
Map<String, String> saslProps, SecretManager<TokenIdentifier> secretManager)
@@ -61,17 +63,29 @@
SaslUtil.safeDispose(saslServer);
}
public void switchToCryptoAES(CryptoAES cryptoAES) {
this.cryptoAES = cryptoAES;
}
public String getAttemptingUser() {
return serverWithProvider.getAttemptingUser().map(Object::toString).orElse("Unknown");
}
public byte[] wrap(byte[] buf, int off, int len) throws SaslException {
if (cryptoAES != null) {
return cryptoAES.wrap(buf, off, len);
} else {
return saslServer.wrap(buf, off, len);
}
}
public byte[] unwrap(byte[] buf, int off, int len) throws SaslException {
if (cryptoAES != null) {
return cryptoAES.unwrap(buf, off, len);
} else {
return saslServer.unwrap(buf, off, len);
}
}
public String getNegotiatedQop() {
return (String) saslServer.getNegotiatedProperty(Sasl.QOP);
@@ -92,4 +106,18 @@
}
return tokenIdentifier;
}
/**
* Unwrap InvalidToken exception, otherwise return the one passed in.
*/
public static Throwable unwrap(Throwable e) {
Throwable cause = e;
while (cause != null) {
if (cause instanceof InvalidToken) {
return cause;
}
cause = cause.getCause();
}
return e;
}
}
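With wrap/unwrap delegating to Crypto AES once switchToCryptoAES is called, callers that captured saslServer::wrap (such as the SaslWrapHandler installed in NettyHBaseSaslRpcServerHandler) keep working across the cipher upgrade without any pipeline surgery. A usage sketch built only from the methods in this diff (the wrapper class and method names are illustrative):

    import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
    import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
    import org.apache.hadoop.hbase.security.SaslWrapHandler;

    final class SwitchToCryptoAesSketch {
      // the handler captures saslServer::wrap once...
      static SaslWrapHandler install(HBaseSaslRpcServer saslServer) {
        return new SaslWrapHandler(saslServer::wrap);
      }

      // ...and after the upgrade the same reference encrypts with Crypto AES.
      // Per the done() callback in ServerRpcConnection, this must run only
      // after the connection header response has been flushed to the client.
      static void upgrade(HBaseSaslRpcServer saslServer, CryptoAES cryptoAES) {
        saslServer.switchToCryptoAES(cryptoAES);
      }
    }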