HBASE-27185 Rewrite NettyRpcServer to decode rpc request with netty handler (#4624)

Signed-off-by: Xin Sun <ddupgs@gmail.com>
Author: Duo Zhang (committed by GitHub)
Date:   2022-07-27 09:00:42 +08:00
parent ac8b3a795f
commit 0c4263a18b
19 changed files with 511 additions and 470 deletions
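
The server-side heart of this change: instead of installing NettyRpcFrameDecoder and NettyRpcServerRequestDecoder up front and wiring the connection into them later via setConnection(), the preamble handler now builds the per-connection tail of the pipeline itself (plain decoders, or SASL negotiation handlers first when the connection requests SASL) and then removes itself together with its single-decode preamble decoder. A minimal standalone sketch of that handoff pattern, using plain io.netty rather than HBase's shaded netty, with illustrative handler names:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.FixedLengthFrameDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;

public class PreambleHandoffSketch {

  // One-shot handler: consumes the fixed-size preamble, installs the real
  // decoders for the rest of the connection, then removes itself and the
  // preamble decoder so any remaining bytes flow to the new handlers.
  static final class PreambleHandler extends SimpleChannelInboundHandler<ByteBuf> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
      // (real code would validate the 6-byte preamble here)
      ChannelPipeline p = ctx.pipeline();
      p.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
      // add first and then remove, so the single-decode preamble decoder
      // passes the remaining bytes on to the handlers installed above
      p.remove(this);
      p.remove("preambleDecoder");
    }
  }

  public static void main(String[] args) {
    EmbeddedChannel ch = new EmbeddedChannel();
    FixedLengthFrameDecoder preambleDecoder = new FixedLengthFrameDecoder(6);
    preambleDecoder.setSingleDecode(true); // decode the preamble exactly once
    ch.pipeline().addLast("preambleDecoder", preambleDecoder);
    ch.pipeline().addLast("preambleHandler", new PreambleHandler());
    ch.writeInbound(Unpooled.wrappedBuffer(new byte[6])); // stand-in preamble
    System.out.println(ch.pipeline().names()); // preamble handlers are gone
  }
}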

File: NettyRpcConnection.java

@@ -223,9 +223,6 @@ class NettyRpcConnection extends RpcConnection {
       public void operationComplete(Future<Boolean> future) throws Exception {
         if (future.isSuccess()) {
           ChannelPipeline p = ch.pipeline();
-          p.remove(SaslChallengeDecoder.class);
-          p.remove(NettyHBaseSaslRpcClientHandler.class);
           // check if negotiate with server for connection header is necessary
           if (saslHandler.isNeedProcessConnectionHeader()) {
             Promise<Boolean> connectionHeaderPromise = ch.eventLoop().newPromise();

File: CryptoAESUnwrapHandler.java (deleted)

@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.hadoop.hbase.security;
-
-import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
-import org.apache.yetus.audience.InterfaceAudience;
-
-import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
-import org.apache.hbase.thirdparty.io.netty.buffer.Unpooled;
-import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
-import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
-
-/**
- * Unwrap messages with Crypto AES. Should be placed after a
- * io.netty.handler.codec.LengthFieldBasedFrameDecoder
- */
-@InterfaceAudience.Private
-public class CryptoAESUnwrapHandler extends SimpleChannelInboundHandler<ByteBuf> {
-
-  private final CryptoAES cryptoAES;
-
-  public CryptoAESUnwrapHandler(CryptoAES cryptoAES) {
-    this.cryptoAES = cryptoAES;
-  }
-
-  @Override
-  protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
-    byte[] bytes = new byte[msg.readableBytes()];
-    msg.readBytes(bytes);
-    ctx.fireChannelRead(Unpooled.wrappedBuffer(cryptoAES.unwrap(bytes, 0, bytes.length)));
-  }
-}

File: CryptoAESWrapHandler.java (deleted)

@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.hadoop.hbase.security;
-
-import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
-import org.apache.yetus.audience.InterfaceAudience;
-
-import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
-import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
-import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
-
-/**
- * wrap messages with Crypto AES.
- */
-@InterfaceAudience.Private
-public class CryptoAESWrapHandler extends MessageToByteEncoder<ByteBuf> {
-
-  private final CryptoAES cryptoAES;
-
-  public CryptoAESWrapHandler(CryptoAES cryptoAES) {
-    this.cryptoAES = cryptoAES;
-  }
-
-  @Override
-  protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
-    byte[] bytes = new byte[msg.readableBytes()];
-    msg.readBytes(bytes);
-    byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length);
-    out.ensureWritable(4 + wrapperBytes.length);
-    out.writeInt(wrapperBytes.length);
-    out.writeBytes(wrapperBytes);
-  }
-}

File: NettyHBaseRpcConnectionHeaderHandler.java

@@ -25,7 +25,6 @@ import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
 import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
-import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
 import org.apache.hbase.thirdparty.io.netty.util.concurrent.Promise;

 import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
@@ -92,10 +91,7 @@ public class NettyHBaseRpcConnectionHeaderHandler extends SimpleChannelInboundHa
    * Remove handlers for sasl encryption and add handlers for Crypto AES encryption
    */
   private void setupCryptoAESHandler(ChannelPipeline p, CryptoAES cryptoAES) {
-    p.remove(SaslWrapHandler.class);
-    p.remove(SaslUnwrapHandler.class);
-    String lengthDecoder = p.context(LengthFieldBasedFrameDecoder.class).name();
-    p.addAfter(lengthDecoder, null, new CryptoAESUnwrapHandler(cryptoAES));
-    p.addAfter(lengthDecoder, null, new CryptoAESWrapHandler(cryptoAES));
+    p.replace(SaslWrapHandler.class, null, new SaslWrapHandler(cryptoAES::wrap));
+    p.replace(SaslUnwrapHandler.class, null, new SaslUnwrapHandler(cryptoAES::unwrap));
   }
 }

File: NettyHBaseSaslRpcClient.java

@@ -52,9 +52,9 @@ public class NettyHBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
       return;
     }
     // add wrap and unwrap handlers to pipeline.
-    p.addFirst(new SaslWrapHandler(saslClient),
+    p.addFirst(new SaslWrapHandler(saslClient::wrap),
       new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4),
-      new SaslUnwrapHandler(saslClient));
+      new SaslUnwrapHandler(saslClient::unwrap));
   }

   public String getSaslQOP() {

File: NettyHBaseSaslRpcClientHandler.java

@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hbase.exceptions.ConnectionClosedException;
 import org.apache.hadoop.hbase.ipc.FallbackDisallowedException;
 import org.apache.hadoop.hbase.security.provider.SaslClientAuthenticationProvider;
+import org.apache.hadoop.hbase.util.NettyFutureUtils;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.security.token.TokenIdentifier;
@@ -33,6 +34,7 @@ import org.slf4j.LoggerFactory;

 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
+import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
 import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
 import org.apache.hbase.thirdparty.io.netty.util.concurrent.Promise;
@@ -77,7 +79,7 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
   private void writeResponse(ChannelHandlerContext ctx, byte[] response) {
     LOG.trace("Sending token size={} from initSASLContext.", response.length);
-    ctx.writeAndFlush(
+    NettyFutureUtils.safeWriteAndFlush(ctx,
       ctx.alloc().buffer(4 + response.length).writeInt(response.length).writeBytes(response));
   }
@@ -90,8 +92,11 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
     if (LOG.isTraceEnabled()) {
       LOG.trace("SASL negotiation for {} is complete", provider.getSaslAuthMethod().getName());
     }
+    ChannelPipeline p = ctx.pipeline();
     saslRpcClient.setupSaslHandler(ctx.pipeline());
+    p.remove(SaslChallengeDecoder.class);
+    p.remove(this);

     setCryptoAESOption();

     saslPromise.setSuccess(true);
@@ -110,6 +115,9 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
   @Override
   public void handlerAdded(ChannelHandlerContext ctx) {
+    // dispose the saslRpcClient when the channel is closed, since saslRpcClient is final, it is
+    // safe to reference it in lambda expr.
+    NettyFutureUtils.addListener(ctx.channel().closeFuture(), f -> saslRpcClient.dispose());
     try {
       byte[] initialResponse = ugi.doAs(new PrivilegedExceptionAction<byte[]>() {
@@ -170,14 +178,12 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
   @Override
   public void channelInactive(ChannelHandlerContext ctx) throws Exception {
-    saslRpcClient.dispose();
     saslPromise.tryFailure(new ConnectionClosedException("Connection closed"));
     ctx.fireChannelInactive();
   }

   @Override
   public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
-    saslRpcClient.dispose();
     saslPromise.tryFailure(cause);
   }
 }

File: SaslUnwrapHandler.java

@@ -17,7 +17,7 @@
  */
 package org.apache.hadoop.hbase.security;

-import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
 import org.apache.yetus.audience.InterfaceAudience;

 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
@@ -32,22 +32,20 @@ import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
 @InterfaceAudience.Private
 public class SaslUnwrapHandler extends SimpleChannelInboundHandler<ByteBuf> {

-  private final SaslClient saslClient;
-
-  public SaslUnwrapHandler(SaslClient saslClient) {
-    this.saslClient = saslClient;
+  public interface Unwrapper {
+    byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException;
   }

-  @Override
-  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
-    SaslUtil.safeDispose(saslClient);
-    ctx.fireChannelInactive();
+  private final Unwrapper unwrapper;
+
+  public SaslUnwrapHandler(Unwrapper unwrapper) {
+    this.unwrapper = unwrapper;
   }

   @Override
   protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
     byte[] bytes = new byte[msg.readableBytes()];
     msg.readBytes(bytes);
-    ctx.fireChannelRead(Unpooled.wrappedBuffer(saslClient.unwrap(bytes, 0, bytes.length)));
+    ctx.fireChannelRead(Unpooled.wrappedBuffer(unwrapper.unwrap(bytes, 0, bytes.length)));
   }
 }

File: SaslWrapHandler.java

@@ -17,7 +17,7 @@
  */
 package org.apache.hadoop.hbase.security;

-import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
 import org.apache.yetus.audience.InterfaceAudience;

 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
@@ -30,17 +30,21 @@ import org.apache.hbase.thirdparty.io.netty.handler.codec.MessageToByteEncoder;
 @InterfaceAudience.Private
 public class SaslWrapHandler extends MessageToByteEncoder<ByteBuf> {

-  private final SaslClient saslClient;
+  public interface Wrapper {
+    byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException;
+  }

-  public SaslWrapHandler(SaslClient saslClient) {
-    this.saslClient = saslClient;
+  private final Wrapper wrapper;
+
+  public SaslWrapHandler(Wrapper wrapper) {
+    this.wrapper = wrapper;
   }

   @Override
   protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception {
     byte[] bytes = new byte[msg.readableBytes()];
     msg.readBytes(bytes);
-    byte[] wrapperBytes = saslClient.wrap(bytes, 0, bytes.length);
+    byte[] wrapperBytes = wrapper.wrap(bytes, 0, bytes.length);
     out.ensureWritable(4 + wrapperBytes.length);
     out.writeInt(wrapperBytes.length);
     out.writeBytes(wrapperBytes);
File: NettyHBaseSaslRpcServerHandler.java (new)

@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.ipc;
+
+import java.io.IOException;
+import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
+import org.apache.hadoop.hbase.security.SaslStatus;
+import org.apache.hadoop.hbase.security.SaslUnwrapHandler;
+import org.apache.hadoop.hbase.security.SaslWrapHandler;
+import org.apache.hadoop.hbase.util.NettyFutureUtils;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
+import org.apache.hbase.thirdparty.io.netty.buffer.ByteBufOutputStream;
+import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
+import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
+import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
+import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+
+/**
+ * Implement SASL negotiation logic for rpc server.
+ */
+class NettyHBaseSaslRpcServerHandler extends SimpleChannelInboundHandler<ByteBuf> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(NettyHBaseSaslRpcServerHandler.class);
+
+  static final String DECODER_NAME = "SaslNegotiationDecoder";
+
+  private final NettyRpcServer rpcServer;
+
+  private final NettyServerRpcConnection conn;
+
+  NettyHBaseSaslRpcServerHandler(NettyRpcServer rpcServer, NettyServerRpcConnection conn) {
+    this.rpcServer = rpcServer;
+    this.conn = conn;
+  }
+
+  private void doResponse(ChannelHandlerContext ctx, SaslStatus status, Writable rv,
+    String errorClass, String error) throws IOException {
+    // In my testing, have noticed that sasl messages are usually
+    // in the ballpark of 100-200. That's why the initial capacity is 256.
+    ByteBuf resp = ctx.alloc().buffer(256);
+    try (ByteBufOutputStream out = new ByteBufOutputStream(resp)) {
+      out.writeInt(status.state); // write status
+      if (status == SaslStatus.SUCCESS) {
+        rv.write(out);
+      } else {
+        WritableUtils.writeString(out, errorClass);
+        WritableUtils.writeString(out, error);
+      }
+    }
+    NettyFutureUtils.safeWriteAndFlush(ctx, resp);
+  }
+
+  @Override
+  protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
+    LOG.debug("Read input token of size={} for processing by saslServer.evaluateResponse()",
+      msg.readableBytes());
+    HBaseSaslRpcServer saslServer = conn.getOrCreateSaslServer();
+    byte[] saslToken = new byte[msg.readableBytes()];
+    msg.readBytes(saslToken, 0, saslToken.length);
+    byte[] replyToken = saslServer.evaluateResponse(saslToken);
+    if (replyToken != null) {
+      LOG.debug("Will send token of size {} from saslServer.", replyToken.length);
+      doResponse(ctx, SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null);
+    }
+    if (saslServer.isComplete()) {
+      conn.finishSaslNegotiation();
+      String qop = saslServer.getNegotiatedQop();
+      boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
+      ChannelPipeline p = ctx.pipeline();
+      if (useWrap) {
+        p.addFirst(new SaslWrapHandler(saslServer::wrap));
+        p.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4),
+          new SaslUnwrapHandler(saslServer::unwrap));
+      }
+      conn.setupDecoder();
+      p.remove(this);
+      p.remove(DECODER_NAME);
+    }
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+    LOG.error("Error when doing SASL handshake, provider={}", conn.provider, cause);
+    Throwable sendToClient = HBaseSaslRpcServer.unwrap(cause);
+    doResponse(ctx, SaslStatus.ERROR, null, sendToClient.getClass().getName(),
+      sendToClient.getLocalizedMessage());
+    rpcServer.metrics.authenticationFailure();
+    String clientIP = this.toString();
+    // attempting user could be null
+    RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
+      conn.saslServer != null ? conn.saslServer.getAttemptingUser() : "Unknown");
+    NettyFutureUtils.safeClose(ctx);
+  }
+}
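
Both sides of the negotiation frame SASL tokens and wrapped payloads the same way: a 4-byte big-endian length followed by the bytes. That is exactly what the LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4) placed in front of SaslUnwrapHandler (and in front of this server handler) undoes: read a 4-byte length field at offset 0, apply no adjustment, strip the 4 prefix bytes, and emit the payload as one frame. A standalone check of those constructor arguments with plain io.netty's EmbeddedChannel (not the shaded classes used in the diff):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;

public class LengthPrefixFramingSketch {
  public static void main(String[] args) {
    EmbeddedChannel ch =
      new EmbeddedChannel(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
    byte[] token = { 1, 2, 3 };
    ByteBuf in = Unpooled.buffer();
    in.writeInt(token.length); // the length field the decoder strips
    in.writeBytes(token);
    ch.writeInbound(in);
    ByteBuf frame = ch.readInbound();
    System.out.println(frame.readableBytes()); // 3 -- prefix removed, payload intact
    frame.release();
  }
}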

File: NettyRpcFrameDecoder.java

@@ -38,21 +38,18 @@ import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos;
  * @since 2.0.0
  */
 @InterfaceAudience.Private
-public class NettyRpcFrameDecoder extends ByteToMessageDecoder {
+class NettyRpcFrameDecoder extends ByteToMessageDecoder {

   private static int FRAME_LENGTH_FIELD_LENGTH = 4;

   private final int maxFrameLength;

+  final NettyServerRpcConnection connection;
+
   private boolean requestTooBig;

   private String requestTooBigMessage;

-  public NettyRpcFrameDecoder(int maxFrameLength) {
+  public NettyRpcFrameDecoder(int maxFrameLength, NettyServerRpcConnection connection) {
     this.maxFrameLength = maxFrameLength;
-  }
-
-  NettyServerRpcConnection connection;
-
-  void setConnection(NettyServerRpcConnection connection) {
     this.connection = connection;
   }
@@ -75,10 +72,10 @@ public class NettyRpcFrameDecoder extends ByteToMessageDecoder {
     if (frameLength > maxFrameLength) {
       requestTooBig = true;
-      requestTooBigMessage = "RPC data length of " + frameLength + " received from "
-        + connection.getHostAddress() + " is greater than max allowed "
-        + connection.rpcServer.maxRequestSize + ". Set \"" + SimpleRpcServer.MAX_REQUEST_SIZE
-        + "\" on server to override this limit (not recommended)";
+      requestTooBigMessage =
+        "RPC data length of " + frameLength + " received from " + connection.getHostAddress()
+          + " is greater than max allowed " + connection.rpcServer.maxRequestSize + ". Set \""
+          + RpcServer.MAX_REQUEST_SIZE + "\" on server to override this limit (not recommended)";

       NettyRpcServer.LOG.warn(requestTooBigMessage);
@@ -132,7 +129,7 @@ public class NettyRpcFrameDecoder extends ByteToMessageDecoder {
     // Make sure the client recognizes the underlying exception
     // Otherwise, throw a DoNotRetryIOException.
     if (
-      VersionInfoUtil.hasMinimumVersion(connection.connectionHeader.getVersionInfo(),
+      VersionInfoUtil.hasMinimumVersion(connection.getVersionInfo(),
         RequestTooBigException.MAJOR_VERSION, RequestTooBigException.MINOR_VERSION)
     ) {
       reqTooBig.setResponse(null, null, reqTooBigEx, requestTooBigMessage);

File: NettyRpcServer.java

@@ -106,11 +106,9 @@ public class NettyRpcServer extends RpcServer {
         ChannelPipeline pipeline = ch.pipeline();
         FixedLengthFrameDecoder preambleDecoder = new FixedLengthFrameDecoder(6);
         preambleDecoder.setSingleDecode(true);
-        pipeline.addLast("preambleDecoder", preambleDecoder);
-        pipeline.addLast("preambleHandler", createNettyRpcServerPreambleHandler());
-        pipeline.addLast("frameDecoder", new NettyRpcFrameDecoder(maxRequestSize));
-        pipeline.addLast("decoder", new NettyRpcServerRequestDecoder(allChannels, metrics));
-        pipeline.addLast("encoder", new NettyRpcServerResponseEncoder(metrics));
+        pipeline.addLast(NettyRpcServerPreambleHandler.DECODER_NAME, preambleDecoder);
+        pipeline.addLast(createNettyRpcServerPreambleHandler(),
+          new NettyRpcServerResponseEncoder(metrics));
       }
     });
     try {
@@ -153,6 +151,7 @@ public class NettyRpcServer extends RpcServer {
     }
   }

+  // will be overridden in tests
   @InterfaceAudience.Private
   protected NettyRpcServerPreambleHandler createNettyRpcServerPreambleHandler() {
     return new NettyRpcServerPreambleHandler(NettyRpcServer.this);

File: NettyRpcServerPreambleHandler.java

@@ -18,6 +18,7 @@
 package org.apache.hadoop.hbase.ipc;

 import java.nio.ByteBuffer;
+import org.apache.hadoop.hbase.util.NettyFutureUtils;
 import org.apache.yetus.audience.InterfaceAudience;

 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
@@ -25,6 +26,7 @@ import org.apache.hbase.thirdparty.io.netty.channel.Channel;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;
 import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
+import org.apache.hbase.thirdparty.io.netty.handler.codec.LengthFieldBasedFrameDecoder;

 /**
  * Handle connection preamble.
@@ -33,6 +35,8 @@ import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;
 @InterfaceAudience.Private
 class NettyRpcServerPreambleHandler extends SimpleChannelInboundHandler<ByteBuf> {

+  static final String DECODER_NAME = "preambleDecoder";
+
   private final NettyRpcServer rpcServer;

   public NettyRpcServerPreambleHandler(NettyRpcServer rpcServer) {
@@ -50,12 +54,29 @@ class NettyRpcServerPreambleHandler extends SimpleChannelInboundHandler<ByteBuf>
       return;
     }
     ChannelPipeline p = ctx.pipeline();
-    ((NettyRpcFrameDecoder) p.get("frameDecoder")).setConnection(conn);
-    ((NettyRpcServerRequestDecoder) p.get("decoder")).setConnection(conn);
+    if (conn.useSasl) {
+      LengthFieldBasedFrameDecoder decoder =
+        new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4);
+      decoder.setSingleDecode(true);
+      p.addLast(NettyHBaseSaslRpcServerHandler.DECODER_NAME, decoder);
+      p.addLast(new NettyHBaseSaslRpcServerHandler(rpcServer, conn));
+    } else {
+      conn.setupDecoder();
+    }
+    // add first and then remove, so the single decode decoder will pass the remaining bytes to the
+    // handler above.
     p.remove(this);
-    p.remove("preambleDecoder");
+    p.remove(DECODER_NAME);
   }

+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+    NettyRpcServer.LOG.warn("Connection {}; caught unexpected downstream exception.",
+      ctx.channel().remoteAddress(), cause);
+    NettyFutureUtils.safeClose(ctx);
+  }
+
+  // will be overridden in tests
   protected NettyServerRpcConnection createNettyServerRpcConnection(Channel channel) {
     return new NettyServerRpcConnection(rpcServer, channel);
   }

File: NettyRpcServerRequestDecoder.java

@@ -17,64 +17,42 @@
  */
 package org.apache.hadoop.hbase.ipc;

+import org.apache.hadoop.hbase.util.NettyFutureUtils;
 import org.apache.yetus.audience.InterfaceAudience;

 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
 import org.apache.hbase.thirdparty.io.netty.channel.ChannelHandlerContext;
-import org.apache.hbase.thirdparty.io.netty.channel.ChannelInboundHandlerAdapter;
-import org.apache.hbase.thirdparty.io.netty.channel.group.ChannelGroup;
+import org.apache.hbase.thirdparty.io.netty.channel.SimpleChannelInboundHandler;

 /**
  * Decoder for rpc request.
  * @since 2.0.0
  */
 @InterfaceAudience.Private
-class NettyRpcServerRequestDecoder extends ChannelInboundHandlerAdapter {
-
-  private final ChannelGroup allChannels;
+class NettyRpcServerRequestDecoder extends SimpleChannelInboundHandler<ByteBuf> {

   private final MetricsHBaseServer metrics;

-  public NettyRpcServerRequestDecoder(ChannelGroup allChannels, MetricsHBaseServer metrics) {
-    this.allChannels = allChannels;
+  private final NettyServerRpcConnection connection;
+
+  public NettyRpcServerRequestDecoder(MetricsHBaseServer metrics,
+    NettyServerRpcConnection connection) {
+    super(false);
     this.metrics = metrics;
-  }
-
-  private NettyServerRpcConnection connection;
-
-  void setConnection(NettyServerRpcConnection connection) {
     this.connection = connection;
   }

-  @Override
-  public void channelActive(ChannelHandlerContext ctx) throws Exception {
-    allChannels.add(ctx.channel());
-    NettyRpcServer.LOG.trace("Connection {}; # active connections={}",
-      ctx.channel().remoteAddress(), (allChannels.size() - 1));
-    super.channelActive(ctx);
-  }
-
-  @Override
-  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
-    ByteBuf input = (ByteBuf) msg;
-    // 4 bytes length field
-    metrics.receivedBytes(input.readableBytes() + 4);
-    connection.process(input);
-  }
-
-  @Override
-  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
-    allChannels.remove(ctx.channel());
-    NettyRpcServer.LOG.trace("Disconnection {}; # active connections={}",
-      ctx.channel().remoteAddress(), (allChannels.size() - 1));
-    super.channelInactive(ctx);
-  }
-
   @Override
   public void exceptionCaught(ChannelHandlerContext ctx, Throwable e) {
-    allChannels.remove(ctx.channel());
-    NettyRpcServer.LOG.trace("Connection {}; caught unexpected downstream exception.",
+    NettyRpcServer.LOG.warn("Connection {}; caught unexpected downstream exception.",
       ctx.channel().remoteAddress(), e);
-    ctx.channel().close();
+    NettyFutureUtils.safeClose(ctx);
+  }
+
+  @Override
+  protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
+    // 4 bytes length field
    metrics.receivedBytes(msg.readableBytes() + 4);
+    connection.process(msg);
   }
 }
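
One subtlety in the rewritten decoder above is the super(false) call. SimpleChannelInboundHandler releases the inbound message after channelRead0 returns unless constructed with autoRelease=false; here ownership of the buffer passes to connection.process(...), which wires buf.release() into callCleanup, so auto-release would be a double free. A sketch of the pattern with plain io.netty (in the real code the release is deferred via the cleanup callback; a finally block keeps the sketch simple):

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;

public class TakeOwnershipSketch extends SimpleChannelInboundHandler<ByteBuf> {

  public TakeOwnershipSketch() {
    super(false); // do NOT auto-release: the buffer is handed off below
  }

  @Override
  protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
    process(msg); // ownership moves here; no further release in this handler
  }

  private void process(ByteBuf msg) {
    try {
      // ... decode the request from msg ...
    } finally {
      msg.release(); // the consumer releases exactly once
    }
  }
}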

File: NettyServerRpcConnection.java

@@ -20,12 +20,12 @@ package org.apache.hadoop.hbase.ipc;
 import java.io.IOException;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
 import org.apache.hadoop.hbase.CellScanner;
 import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup;
 import org.apache.hadoop.hbase.nio.ByteBuff;
 import org.apache.hadoop.hbase.nio.SingleByteBuff;
 import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
+import org.apache.hadoop.hbase.util.NettyFutureUtils;
 import org.apache.yetus.audience.InterfaceAudience;

 import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService;
@@ -33,7 +33,7 @@ import org.apache.hbase.thirdparty.com.google.protobuf.Descriptors.MethodDescrip
 import org.apache.hbase.thirdparty.com.google.protobuf.Message;
 import org.apache.hbase.thirdparty.io.netty.buffer.ByteBuf;
 import org.apache.hbase.thirdparty.io.netty.channel.Channel;
-import org.apache.hbase.thirdparty.io.netty.util.ReferenceCountUtil;
+import org.apache.hbase.thirdparty.io.netty.channel.ChannelPipeline;

 import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader;
@@ -49,10 +49,16 @@ class NettyServerRpcConnection extends ServerRpcConnection {
   NettyServerRpcConnection(NettyRpcServer rpcServer, Channel channel) {
     super(rpcServer);
     this.channel = channel;
+    rpcServer.allChannels.add(channel);
+    NettyRpcServer.LOG.trace("Connection {}; # active connections={}", channel.remoteAddress(),
+      rpcServer.allChannels.size() - 1);
     // register close hook to release resources
-    channel.closeFuture().addListener(f -> {
+    NettyFutureUtils.addListener(channel.closeFuture(), f -> {
       disposeSasl();
       callCleanupIfNeeded();
+      NettyRpcServer.LOG.trace("Disconnection {}; # active connections={}",
+        channel.remoteAddress(), rpcServer.allChannels.size() - 1);
+      rpcServer.allChannels.remove(channel);
     });
     InetSocketAddress inetSocketAddress = ((InetSocketAddress) channel.remoteAddress());
     this.addr = inetSocketAddress.getAddress();
@@ -64,38 +70,22 @@ class NettyServerRpcConnection extends ServerRpcConnection {
     this.remotePort = inetSocketAddress.getPort();
   }

-  void process(final ByteBuf buf) throws IOException, InterruptedException {
-    if (connectionHeaderRead) {
-      this.callCleanup = () -> ReferenceCountUtil.safeRelease(buf);
-      process(new SingleByteBuff(buf.nioBuffer()));
-    } else {
-      ByteBuffer connectionHeader = ByteBuffer.allocate(buf.readableBytes());
-      try {
-        buf.readBytes(connectionHeader);
-      } finally {
-        buf.release();
-      }
-      process(connectionHeader);
-    }
+  void setupDecoder() {
+    ChannelPipeline p = channel.pipeline();
+    p.addLast("frameDecoder", new NettyRpcFrameDecoder(rpcServer.maxRequestSize, this));
+    p.addLast("decoder", new NettyRpcServerRequestDecoder(rpcServer.metrics, this));
   }

-  void process(ByteBuffer buf) throws IOException, InterruptedException {
-    process(new SingleByteBuff(buf));
-  }
-
-  void process(ByteBuff buf) throws IOException, InterruptedException {
+  void process(ByteBuf buf) throws IOException, InterruptedException {
+    if (skipInitialSaslHandshake) {
+      skipInitialSaslHandshake = false;
+      buf.release();
+      return;
+    }
+    this.callCleanup = () -> buf.release();
+    ByteBuff byteBuff = new SingleByteBuff(buf.nioBuffer());
     try {
-      if (skipInitialSaslHandshake) {
-        skipInitialSaslHandshake = false;
-        callCleanupIfNeeded();
-        return;
-      }
-      if (useSasl) {
-        saslReadAndProcess(buf);
-      } else {
-        processOneRpc(buf);
-      }
+      processOneRpc(byteBuff);
     } catch (Exception e) {
       callCleanupIfNeeded();
       throw e;

File: ServerCall.java

@@ -399,37 +399,6 @@ public abstract class ServerCall<T extends ServerRpcConnection> implements RpcCa
     return pbBuf;
   }

-  protected BufferChain wrapWithSasl(BufferChain bc) throws IOException {
-    if (!this.connection.useSasl) {
-      return bc;
-    }
-    // Looks like no way around this; saslserver wants a byte array. I have to make it one.
-    // THIS IS A BIG UGLY COPY.
-    byte[] responseBytes = bc.getBytes();
-    byte[] token;
-    // synchronization may be needed since there can be multiple Handler
-    // threads using saslServer or Crypto AES to wrap responses.
-    if (connection.useCryptoAesWrap) {
-      // wrap with Crypto AES
-      synchronized (connection.cryptoAES) {
-        token = connection.cryptoAES.wrap(responseBytes, 0, responseBytes.length);
-      }
-    } else {
-      synchronized (connection.saslServer) {
-        token = connection.saslServer.wrap(responseBytes, 0, responseBytes.length);
-      }
-    }
-    if (RpcServer.LOG.isTraceEnabled()) {
-      RpcServer.LOG
-        .trace("Adding saslServer wrapped token of size " + token.length + " as call response.");
-    }
-    ByteBuffer[] responseBufs = new ByteBuffer[2];
-    responseBufs[0] = ByteBuffer.wrap(Bytes.toBytes(token.length));
-    responseBufs[1] = ByteBuffer.wrap(token);
-    return new BufferChain(responseBufs);
-  }
-
   @Override
   public long disconnectSince() {
     if (!this.connection.isConnectionOpen()) {
@@ -556,20 +525,6 @@ public abstract class ServerCall<T extends ServerRpcConnection> implements RpcCa
   @Override
   public synchronized BufferChain getResponse() {
-    if (connection.useWrap) {
-      /*
-       * wrapping result with SASL as the last step just before sending it out, so every message
-       * must have the right increasing sequence number
-       */
-      try {
-        return wrapWithSasl(response);
-      } catch (IOException e) {
-        /* it is exactly the same what setResponse() does */
-        RpcServer.LOG.warn("Exception while creating response " + e);
-        return null;
-      }
-    } else {
-      return response;
-    }
+    return response;
   }
 }

File: ServerRpcConnection.java

@@ -24,15 +24,12 @@ import io.opentelemetry.api.trace.Span;
 import io.opentelemetry.context.Context;
 import io.opentelemetry.context.Scope;
 import io.opentelemetry.context.propagation.TextMapGetter;
-import java.io.ByteArrayInputStream;
 import java.io.Closeable;
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
-import java.nio.channels.Channels;
-import java.nio.channels.ReadableByteChannel;
 import java.security.GeneralSecurityException;
 import java.util.Objects;
 import java.util.Properties;
@@ -47,7 +44,6 @@ import org.apache.hadoop.hbase.io.ByteBufferOutputStream;
 import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
 import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup;
 import org.apache.hadoop.hbase.nio.ByteBuff;
-import org.apache.hadoop.hbase.nio.SingleByteBuff;
 import org.apache.hadoop.hbase.security.AccessDeniedException;
 import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
 import org.apache.hadoop.hbase.security.SaslStatus;
@@ -58,7 +54,7 @@ import org.apache.hadoop.hbase.security.provider.SaslServerAuthenticationProvide
 import org.apache.hadoop.hbase.security.provider.SimpleSaslServerAuthenticationProvider;
 import org.apache.hadoop.hbase.trace.TraceUtil;
 import org.apache.hadoop.hbase.util.Bytes;
-import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.hbase.util.Pair;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableUtils;
@@ -67,7 +63,6 @@ import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
 import org.apache.hadoop.security.authorize.AuthorizationException;
 import org.apache.hadoop.security.authorize.ProxyUsers;
-import org.apache.hadoop.security.token.SecretManager.InvalidToken;
 import org.apache.yetus.audience.InterfaceAudience;

 import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService;
@@ -120,16 +115,9 @@ abstract class ServerRpcConnection implements Closeable {
   protected BlockingService service;

   protected SaslServerAuthenticationProvider provider;
-  protected boolean saslContextEstablished;
   protected boolean skipInitialSaslHandshake;
-  private ByteBuffer unwrappedData;
-  // When is this set? FindBugs wants to know! Says NP
-  private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
   protected boolean useSasl;
   protected HBaseSaslRpcServer saslServer;
-  protected CryptoAES cryptoAES;
-  protected boolean useWrap = false;
-  protected boolean useCryptoAesWrap = false;

   // was authentication allowed with a fallback to simple auth
   protected boolean authenticatedWithFallback;
@@ -164,7 +152,7 @@ abstract class ServerRpcConnection implements Closeable {
   }

   public VersionInfo getVersionInfo() {
-    if (connectionHeader.hasVersionInfo()) {
+    if (connectionHeader != null && connectionHeader.hasVersionInfo()) {
       return connectionHeader.getVersionInfo();
     }
     return null;
@@ -181,18 +169,24 @@ abstract class ServerRpcConnection implements Closeable {
   /**
    * Set up cell block codecs n
    */
-  private void setupCellBlockCodecs(final ConnectionHeader header) throws FatalConnectionException {
+  private void setupCellBlockCodecs() throws FatalConnectionException {
     // TODO: Plug in other supported decoders.
-    if (!header.hasCellBlockCodecClass()) return;
-    String className = header.getCellBlockCodecClass();
-    if (className == null || className.length() == 0) return;
+    if (!connectionHeader.hasCellBlockCodecClass()) {
+      return;
+    }
+    String className = connectionHeader.getCellBlockCodecClass();
+    if (className == null || className.length() == 0) {
+      return;
+    }
     try {
       this.codec = (Codec) Class.forName(className).getDeclaredConstructor().newInstance();
     } catch (Exception e) {
       throw new UnsupportedCellCodecException(className, e);
     }
-    if (!header.hasCellBlockCompressorClass()) return;
-    className = header.getCellBlockCompressorClass();
+    if (!connectionHeader.hasCellBlockCompressorClass()) {
+      return;
+    }
+    className = connectionHeader.getCellBlockCompressorClass();
     try {
       this.compressionCodec =
         (CompressionCodec) Class.forName(className).getDeclaredConstructor().newInstance();
@@ -202,21 +196,29 @@ abstract class ServerRpcConnection implements Closeable {
   }

   /**
-   * Set up cipher for rpc encryption with Apache Commons Crypto n
+   * Set up cipher for rpc encryption with Apache Commons Crypto.
    */
-  private void setupCryptoCipher(final ConnectionHeader header,
-    RPCProtos.ConnectionHeaderResponse.Builder chrBuilder) throws FatalConnectionException {
+  private Pair<RPCProtos.ConnectionHeaderResponse, CryptoAES> setupCryptoCipher()
+    throws FatalConnectionException {
     // If simple auth, return
-    if (saslServer == null) return;
+    if (saslServer == null) {
+      return null;
+    }
     // check if rpc encryption with Crypto AES
     String qop = saslServer.getNegotiatedQop();
     boolean isEncryption = SaslUtil.QualityOfProtection.PRIVACY.getSaslQop().equalsIgnoreCase(qop);
     boolean isCryptoAesEncryption = isEncryption
       && this.rpcServer.conf.getBoolean("hbase.rpc.crypto.encryption.aes.enabled", false);
-    if (!isCryptoAesEncryption) return;
-    if (!header.hasRpcCryptoCipherTransformation()) return;
-    String transformation = header.getRpcCryptoCipherTransformation();
-    if (transformation == null || transformation.length() == 0) return;
+    if (!isCryptoAesEncryption) {
+      return null;
+    }
+    if (!connectionHeader.hasRpcCryptoCipherTransformation()) {
+      return null;
+    }
+    String transformation = connectionHeader.getRpcCryptoCipherTransformation();
+    if (transformation == null || transformation.length() == 0) {
+      return null;
+    }
     // Negotiates AES based on complete saslServer.
     // The Crypto metadata need to be encrypted and send to client.
     Properties properties = new Properties();
@@ -242,6 +244,7 @@ abstract class ServerRpcConnection implements Closeable {
     byte[] inIv = new byte[len];
     byte[] outIv = new byte[len];

+    CryptoAES cryptoAES;
     try {
       // generate the cipher meta data with SecureRandom
       CryptoRandom secureRandom = CryptoRandomFactory.getCryptoRandom(properties);
@@ -252,19 +255,20 @@ abstract class ServerRpcConnection implements Closeable {
       // create CryptoAES for server
       cryptoAES = new CryptoAES(transformation, properties, inKey, outKey, inIv, outIv);
-      // create SaslCipherMeta and send to client,
-      // for client, the [inKey, outKey], [inIv, outIv] should be reversed
-      RPCProtos.CryptoCipherMeta.Builder ccmBuilder = RPCProtos.CryptoCipherMeta.newBuilder();
-      ccmBuilder.setTransformation(transformation);
-      ccmBuilder.setInIv(getByteString(outIv));
-      ccmBuilder.setInKey(getByteString(outKey));
-      ccmBuilder.setOutIv(getByteString(inIv));
-      ccmBuilder.setOutKey(getByteString(inKey));
-      chrBuilder.setCryptoCipherMeta(ccmBuilder);
-      useCryptoAesWrap = true;
     } catch (GeneralSecurityException | IOException ex) {
       throw new UnsupportedCryptoException(ex.getMessage(), ex);
     }
+    // create SaslCipherMeta and send to client,
+    // for client, the [inKey, outKey], [inIv, outIv] should be reversed
+    RPCProtos.CryptoCipherMeta.Builder ccmBuilder = RPCProtos.CryptoCipherMeta.newBuilder();
+    ccmBuilder.setTransformation(transformation);
+    ccmBuilder.setInIv(getByteString(outIv));
+    ccmBuilder.setInKey(getByteString(outKey));
+    ccmBuilder.setOutIv(getByteString(inIv));
+    ccmBuilder.setOutKey(getByteString(inKey));
+    RPCProtos.ConnectionHeaderResponse resp =
+      RPCProtos.ConnectionHeaderResponse.newBuilder().setCryptoCipherMeta(ccmBuilder).build();
+    return Pair.newPair(resp, cryptoAES);
   }
   private ByteString getByteString(byte[] bytes) {
@@ -327,125 +331,20 @@ abstract class ServerRpcConnection implements Closeable {
     doRespond(() -> bc);
   }

-  public void saslReadAndProcess(ByteBuff saslToken) throws IOException, InterruptedException {
-    if (saslContextEstablished) {
-      RpcServer.LOG.trace("Read input token of size={} for processing by saslServer.unwrap()",
-        saslToken.limit());
-      if (!useWrap) {
-        processOneRpc(saslToken);
-      } else {
-        byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes();
-        byte[] plaintextData;
-        if (useCryptoAesWrap) {
-          // unwrap with CryptoAES
-          plaintextData = cryptoAES.unwrap(b, 0, b.length);
-        } else {
-          plaintextData = saslServer.unwrap(b, 0, b.length);
-        }
-        // release the request buffer as we have already unwrapped all its content
-        callCleanupIfNeeded();
-        processUnwrappedData(plaintextData);
-      }
-    } else {
-      byte[] replyToken;
-      try {
-        if (saslServer == null) {
-          try {
-            saslServer =
-              new HBaseSaslRpcServer(provider, rpcServer.saslProps, rpcServer.secretManager);
-          } catch (Exception e) {
-            RpcServer.LOG.error("Error when trying to create instance of HBaseSaslRpcServer "
-              + "with sasl provider: " + provider, e);
-            throw e;
-          }
-          RpcServer.LOG.debug("Created SASL server with mechanism={}",
-            provider.getSaslAuthMethod().getAuthMethod());
-        }
-        RpcServer.LOG.debug(
-          "Read input token of size={} for processing by saslServer." + "evaluateResponse()",
-          saslToken.limit());
-        replyToken = saslServer
-          .evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes());
-      } catch (IOException e) {
-        RpcServer.LOG.debug("Failed to execute SASL handshake", e);
-        IOException sendToClient = e;
-        Throwable cause = e;
-        while (cause != null) {
-          if (cause instanceof InvalidToken) {
-            sendToClient = (InvalidToken) cause;
-            break;
-          }
-          cause = cause.getCause();
-        }
-        doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(),
-          sendToClient.getLocalizedMessage());
-        this.rpcServer.metrics.authenticationFailure();
-        String clientIP = this.toString();
-        // attempting user could be null
-        RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
-          saslServer.getAttemptingUser());
-        throw e;
-      } finally {
-        // release the request buffer as we have already unwrapped all its content
-        callCleanupIfNeeded();
-      }
-      if (replyToken != null) {
-        if (RpcServer.LOG.isDebugEnabled()) {
-          RpcServer.LOG.debug("Will send token of size " + replyToken.length + " from saslServer.");
-        }
-        doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null);
-      }
-      if (saslServer.isComplete()) {
-        String qop = saslServer.getNegotiatedQop();
-        useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
-        ugi =
-          provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager);
-        RpcServer.LOG.debug(
-          "SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi,
-          qop);
-        this.rpcServer.metrics.authenticationSuccess();
-        RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi);
-        saslContextEstablished = true;
-      }
-    }
+  HBaseSaslRpcServer getOrCreateSaslServer() throws IOException {
+    if (saslServer == null) {
+      saslServer = new HBaseSaslRpcServer(provider, rpcServer.saslProps, rpcServer.secretManager);
+    }
+    return saslServer;
   }

-  private void processUnwrappedData(byte[] inBuf) throws IOException, InterruptedException {
-    ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf));
-    // Read all RPCs contained in the inBuf, even partial ones
-    while (true) {
-      int count;
-      if (unwrappedDataLengthBuffer.remaining() > 0) {
-        count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer);
-        if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) {
-          return;
-        }
-      }
-      if (unwrappedData == null) {
-        unwrappedDataLengthBuffer.flip();
-        int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
-        if (unwrappedDataLength == RpcClient.PING_CALL_ID) {
-          if (RpcServer.LOG.isDebugEnabled()) RpcServer.LOG.debug("Received ping message");
-          unwrappedDataLengthBuffer.clear();
-          continue; // ping message
-        }
-        unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
-      }
-      count = this.rpcServer.channelRead(ch, unwrappedData);
-      if (count <= 0 || unwrappedData.remaining() > 0) {
-        return;
-      }
-      if (unwrappedData.remaining() == 0) {
-        unwrappedDataLengthBuffer.clear();
-        unwrappedData.flip();
-        processOneRpc(new SingleByteBuff(unwrappedData));
-        unwrappedData = null;
-      }
-    }
+  void finishSaslNegotiation() throws IOException {
+    String qop = saslServer.getNegotiatedQop();
+    ugi = provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager);
+    RpcServer.LOG.debug(
+      "SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi, qop);
+    rpcServer.metrics.authenticationSuccess();
+    RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi);
   }
   public void processOneRpc(ByteBuff buf) throws IOException, InterruptedException {
@@ -453,6 +352,7 @@ abstract class ServerRpcConnection implements Closeable {
       processRequest(buf);
     } else {
       processConnectionHeader(buf);
+      callCleanupIfNeeded();
       this.connectionHeaderRead = true;
       if (rpcServer.needAuthorization() && !authorizeConnection()) {
         // Throw FatalConnectionException wrapping ACE so client does right thing and closes
@@ -486,25 +386,35 @@ abstract class ServerRpcConnection implements Closeable {
     return true;
   }

+  private CodedInputStream createCis(ByteBuff buf) {
+    // Here we read in the header. We avoid having pb
+    // do its default 4k allocation for CodedInputStream. We force it to use
+    // backing array.
+    CodedInputStream cis;
+    if (buf.hasArray()) {
+      cis = UnsafeByteOperations
+        .unsafeWrap(buf.array(), buf.arrayOffset() + buf.position(), buf.limit()).newCodedInput();
+    } else {
+      cis = UnsafeByteOperations.unsafeWrap(new ByteBuffByteInput(buf, buf.limit()), 0, buf.limit())
+        .newCodedInput();
+    }
+    cis.enableAliasing(true);
+    return cis;
+  }
+
   // Reads the connection header following version
   private void processConnectionHeader(ByteBuff buf) throws IOException {
-    if (buf.hasArray()) {
-      this.connectionHeader = ConnectionHeader.parseFrom(buf.array());
-    } else {
-      CodedInputStream cis = UnsafeByteOperations
-        .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput();
-      cis.enableAliasing(true);
-      this.connectionHeader = ConnectionHeader.parseFrom(cis);
-    }
+    this.connectionHeader = ConnectionHeader.parseFrom(createCis(buf));
     String serviceName = connectionHeader.getServiceName();
-    if (serviceName == null) throw new EmptyServiceNameException();
+    if (serviceName == null) {
+      throw new EmptyServiceNameException();
+    }
     this.service = RpcServer.getService(this.rpcServer.services, serviceName);
-    if (this.service == null) throw new UnknownServiceException(serviceName);
-    setupCellBlockCodecs(this.connectionHeader);
-    RPCProtos.ConnectionHeaderResponse.Builder chrBuilder =
-      RPCProtos.ConnectionHeaderResponse.newBuilder();
-    setupCryptoCipher(this.connectionHeader, chrBuilder);
-    responseConnectionHeader(chrBuilder);
+    if (this.service == null) {
+      throw new UnknownServiceException(serviceName);
+    }
+    setupCellBlockCodecs();
+    sendConnectionHeaderResponseIfNeeded();
     UserGroupInformation protocolUser = createUser(connectionHeader);
     if (!useSasl) {
       ugi = protocolUser;
@ -553,25 +463,35 @@ abstract class ServerRpcConnection implements Closeable {
   /**
    * Send the response for connection header
    */
-  private void responseConnectionHeader(RPCProtos.ConnectionHeaderResponse.Builder chrBuilder)
-    throws FatalConnectionException {
+  private void sendConnectionHeaderResponseIfNeeded() throws FatalConnectionException {
+    Pair<RPCProtos.ConnectionHeaderResponse, CryptoAES> pair = setupCryptoCipher();
     // Response the connection header if Crypto AES is enabled
-    if (!chrBuilder.hasCryptoCipherMeta()) return;
+    if (pair == null) {
+      return;
+    }
     try {
-      byte[] connectionHeaderResBytes = chrBuilder.build().toByteArray();
-      // encrypt the Crypto AES cipher meta data with sasl server, and send to client
-      byte[] unwrapped = new byte[connectionHeaderResBytes.length + 4];
-      Bytes.putBytes(unwrapped, 0, Bytes.toBytes(connectionHeaderResBytes.length), 0, 4);
-      Bytes.putBytes(unwrapped, 4, connectionHeaderResBytes, 0, connectionHeaderResBytes.length);
-      byte[] wrapped = saslServer.wrap(unwrapped, 0, unwrapped.length);
+      int size = pair.getFirst().getSerializedSize();
       BufferChain bc;
-      try (ByteBufferOutputStream response = new ByteBufferOutputStream(wrapped.length + 4);
-        DataOutputStream out = new DataOutputStream(response)) {
-        out.writeInt(wrapped.length);
-        out.write(wrapped);
-        bc = new BufferChain(response.getByteBuffer());
+      try (ByteBufferOutputStream bbOut = new ByteBufferOutputStream(4 + size);
+        DataOutputStream out = new DataOutputStream(bbOut)) {
+        out.writeInt(size);
+        pair.getFirst().writeTo(out);
+        bc = new BufferChain(bbOut.getByteBuffer());
       }
-      doRespond(() -> bc);
+      doRespond(new RpcResponse() {
+
+        @Override
+        public BufferChain getResponse() {
+          return bc;
+        }
+
+        @Override
+        public void done() {
+          // must switch after sending the connection header response, as the client still uses the
+          // original SaslClient to unwrap the data we send back
+          saslServer.switchToCryptoAES(pair.getSecond());
+        }
+      });
     } catch (IOException ex) {
       throw new UnsupportedCryptoException(ex.getMessage(), ex);
     }
@ -581,7 +501,9 @@ abstract class ServerRpcConnection implements Closeable {
   /**
    * Has the request header and the request param and optionally encoded data buffer all in this
    * one array.
+   * <p/>
+   * Will be overridden in tests.
    */
   protected void processRequest(ByteBuff buf) throws IOException, InterruptedException {
     long totalRequestSize = buf.limit();
@ -589,14 +511,7 @@ abstract class ServerRpcConnection implements Closeable {
     // Here we read in the header. We avoid having pb
     // do its default 4k allocation for CodedInputStream. We force it to use
     // backing array.
-    CodedInputStream cis;
-    if (buf.hasArray()) {
-      cis = UnsafeByteOperations.unsafeWrap(buf.array(), 0, buf.limit()).newCodedInput();
-    } else {
-      cis = UnsafeByteOperations
-        .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput();
-    }
-    cis.enableAliasing(true);
+    CodedInputStream cis = createCis(buf);
     int headerSize = cis.readRawVarint32();
     offset = cis.getTotalBytesRead();
     Message.Builder builder = RequestHeader.newBuilder();
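For context on the header parse above: each request buffer carries a varint-delimited RequestHeader, then (when the header says so) a varint-delimited param message, then an optional cell block. A generic sketch of reading one varint-delimited message with the shaded protobuf; note the committed code tracks offsets into the ByteBuff by hand rather than using pushLimit, so this is only to make the wire layout concrete:

    static <B extends Message.Builder> B readDelimited(CodedInputStream cis, B builder)
      throws IOException {
      int size = cis.readRawVarint32(); // message length, encoded as a protobuf varint
      int oldLimit = cis.pushLimit(size); // confine parsing to this one message
      builder.mergeFrom(cis);
      cis.popLimit(oldLimit);
      return builder;
    }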
@ -737,7 +652,7 @@ abstract class ServerRpcConnection implements Closeable {
   }

   private void doBadPreambleHandling(String msg, Exception e) throws IOException {
-    SimpleRpcServer.LOG.warn(msg);
+    RpcServer.LOG.warn(msg);
     doRespond(getErrorResponse(msg, e));
   }
@ -762,7 +677,7 @@ abstract class ServerRpcConnection implements Closeable {
     int version = preambleBuffer.get() & 0xFF;
     byte authbyte = preambleBuffer.get();
-    if (version != SimpleRpcServer.CURRENT_VERSION) {
+    if (version != RpcServer.CURRENT_VERSION) {
       String msg = getFatalConnectionString(version, authbyte);
       doBadPreambleHandling(msg, new WrongVersionException(msg));
       return false;
@ -810,34 +725,28 @@ abstract class ServerRpcConnection implements Closeable {
   private static class ByteBuffByteInput extends ByteInput {

     private ByteBuff buf;
-    private int offset;
     private int length;

-    ByteBuffByteInput(ByteBuff buf, int offset, int length) {
+    ByteBuffByteInput(ByteBuff buf, int length) {
       this.buf = buf;
-      this.offset = offset;
       this.length = length;
     }

     @Override
     public byte read(int offset) {
-      return this.buf.get(getAbsoluteOffset(offset));
+      return this.buf.get(offset);
     }

-    private int getAbsoluteOffset(int offset) {
-      return this.offset + offset;
-    }
-
     @Override
     public int read(int offset, byte[] out, int outOffset, int len) {
-      this.buf.get(getAbsoluteOffset(offset), out, outOffset, len);
+      this.buf.get(offset, out, outOffset, len);
       return len;
     }

     @Override
     public int read(int offset, ByteBuffer out) {
       int len = out.remaining();
-      this.buf.get(out, getAbsoluteOffset(offset), len);
+      this.buf.get(out, offset, len);
       return len;
     }
View File
@ -18,6 +18,7 @@
 package org.apache.hadoop.hbase.ipc;

 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.nio.channels.CancelledKeyException;
 import java.nio.channels.ClosedChannelException;
 import java.nio.channels.SelectionKey;
@ -28,6 +29,8 @@ import java.util.Iterator;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import org.apache.hadoop.hbase.HBaseIOException;
+import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
+import org.apache.hadoop.hbase.util.Bytes;
 import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
 import org.apache.hadoop.hbase.util.Threads;
 import org.apache.hadoop.util.StringUtils;
@ -217,6 +220,28 @@ class SimpleRpcServerResponder extends Thread {
     }
   }

+  private BufferChain wrapWithSasl(HBaseSaslRpcServer saslServer, BufferChain bc)
+    throws IOException {
+    // Looks like no way around this; saslserver wants a byte array. I have to make it one.
+    // THIS IS A BIG UGLY COPY.
+    byte[] responseBytes = bc.getBytes();
+    byte[] token;
+    // synchronization may be needed since there can be multiple Handler
+    // threads using saslServer or Crypto AES to wrap responses.
+    synchronized (saslServer) {
+      token = saslServer.wrap(responseBytes, 0, responseBytes.length);
+    }
+    if (SimpleRpcServer.LOG.isTraceEnabled()) {
+      SimpleRpcServer.LOG
+        .trace("Adding saslServer wrapped token of size " + token.length + " as call response.");
+    }
+    ByteBuffer[] responseBufs = new ByteBuffer[2];
+    responseBufs[0] = ByteBuffer.wrap(Bytes.toBytes(token.length));
+    responseBufs[1] = ByteBuffer.wrap(token);
+    return new BufferChain(responseBufs);
+  }
+
   /**
    * Process the response for this call. You need to have the lock on
    * {@link org.apache.hadoop.hbase.ipc.SimpleServerRpcConnection#responseWriteLock}
@ -226,6 +251,9 @@ class SimpleRpcServerResponder extends Thread {
     throws IOException {
     boolean error = true;
     BufferChain buf = resp.getResponse();
+    if (conn.useWrap) {
+      buf = wrapWithSasl(conn.saslServer, buf);
+    }
     try {
       // Send as much data as we can in the non-blocking fashion
       long numBytes = this.simpleRpcServer.channelWrite(conn.channel, buf);
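On the wire, every QOP-protected response is now a [4-byte token length][wrapped token] frame: wrapWithSasl() above writes the length, then the token. The peer reads one frame, unwraps it, and parses ordinary RPC responses out of the plaintext. A minimal reader-side sketch over a blocking stream, assuming a negotiated javax.security.sasl.SaslClient (illustrative, not the HBase client code):

    static byte[] readAndUnwrapFrame(DataInputStream in, SaslClient saslClient) throws IOException {
      int tokenLen = in.readInt(); // length prefix written by wrapWithSasl
      byte[] token = new byte[tokenLen];
      in.readFully(token);
      return saslClient.unwrap(token, 0, tokenLen); // plaintext; may contain several responses
    }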
View File
@ -17,11 +17,13 @@
  */
 package org.apache.hadoop.hbase.ipc;

+import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.net.InetAddress;
 import java.net.Socket;
 import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
 import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.SocketChannel;
 import java.util.concurrent.ConcurrentLinkedDeque;
@ -34,7 +36,11 @@ import org.apache.hadoop.hbase.client.VersionInfoUtil;
 import org.apache.hadoop.hbase.exceptions.RequestTooBigException;
 import org.apache.hadoop.hbase.ipc.RpcServer.CallCleanup;
 import org.apache.hadoop.hbase.nio.ByteBuff;
+import org.apache.hadoop.hbase.nio.SingleByteBuff;
+import org.apache.hadoop.hbase.security.HBaseSaslRpcServer;
+import org.apache.hadoop.hbase.security.SaslStatus;
 import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
+import org.apache.hadoop.io.BytesWritable;
 import org.apache.yetus.audience.InterfaceAudience;

 import org.apache.hbase.thirdparty.com.google.protobuf.BlockingService;
@ -63,6 +69,11 @@ class SimpleServerRpcConnection extends ServerRpcConnection {
   // If initial preamble with version and magic has been read or not.
   private boolean connectionPreambleRead = false;

+  private boolean saslContextEstablished;
+  private ByteBuffer unwrappedData;
+  // When is this set? FindBugs wants to know! Says NP
+  private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
+  boolean useWrap = false;
+
   final ConcurrentLinkedDeque<RpcResponse> responseQueue = new ConcurrentLinkedDeque<>();
   final Lock responseWriteLock = new ReentrantLock();
@ -142,6 +153,110 @@ class SimpleServerRpcConnection extends ServerRpcConnection {
     }
   }

+  private void processUnwrappedData(byte[] inBuf) throws IOException, InterruptedException {
+    ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf));
+    // Read all RPCs contained in the inBuf, even partial ones
+    while (true) {
+      int count;
+      if (unwrappedDataLengthBuffer.remaining() > 0) {
+        count = this.rpcServer.channelRead(ch, unwrappedDataLengthBuffer);
+        if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) {
+          return;
+        }
+      }
+      if (unwrappedData == null) {
+        unwrappedDataLengthBuffer.flip();
+        int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
+        if (unwrappedDataLength == RpcClient.PING_CALL_ID) {
+          if (RpcServer.LOG.isDebugEnabled()) RpcServer.LOG.debug("Received ping message");
+          unwrappedDataLengthBuffer.clear();
+          continue; // ping message
+        }
+        unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
+      }
+      count = this.rpcServer.channelRead(ch, unwrappedData);
+      if (count <= 0 || unwrappedData.remaining() > 0) {
+        return;
+      }
+      if (unwrappedData.remaining() == 0) {
+        unwrappedDataLengthBuffer.clear();
+        unwrappedData.flip();
+        processOneRpc(new SingleByteBuff(unwrappedData));
+        unwrappedData = null;
+      }
+    }
+  }
+  private void saslReadAndProcess(ByteBuff saslToken) throws IOException, InterruptedException {
+    if (saslContextEstablished) {
+      RpcServer.LOG.trace("Read input token of size={} for processing by saslServer.unwrap()",
+        saslToken.limit());
+      if (!useWrap) {
+        processOneRpc(saslToken);
+      } else {
+        byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes();
+        byte[] plaintextData = saslServer.unwrap(b, 0, b.length);
+        // release the request buffer as we have already unwrapped all its content
+        callCleanupIfNeeded();
+        processUnwrappedData(plaintextData);
+      }
+    } else {
+      byte[] replyToken;
+      try {
+        try {
+          getOrCreateSaslServer();
+        } catch (Exception e) {
+          RpcServer.LOG.error("Error when trying to create instance of HBaseSaslRpcServer "
+            + "with sasl provider: " + provider, e);
+          throw e;
+        }
+        RpcServer.LOG.debug("Created SASL server with mechanism={}",
+          provider.getSaslAuthMethod().getAuthMethod());
+        RpcServer.LOG.debug(
+          "Read input token of size={} for processing by saslServer.evaluateResponse()",
+          saslToken.limit());
+        replyToken = saslServer
+          .evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes());
+      } catch (IOException e) {
+        RpcServer.LOG.debug("Failed to execute SASL handshake", e);
+        Throwable sendToClient = HBaseSaslRpcServer.unwrap(e);
+        doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(),
+          sendToClient.getLocalizedMessage());
+        this.rpcServer.metrics.authenticationFailure();
+        String clientIP = this.toString();
+        // attempting user could be null
+        RpcServer.AUDITLOG.warn("{}{}: {}", RpcServer.AUTH_FAILED_FOR, clientIP,
+          saslServer.getAttemptingUser());
+        throw e;
+      } finally {
+        // release the request buffer as we have already unwrapped all its content
+        callCleanupIfNeeded();
+      }
+      if (replyToken != null) {
+        if (RpcServer.LOG.isDebugEnabled()) {
+          RpcServer.LOG.debug("Will send token of size " + replyToken.length + " from saslServer.");
+        }
+        doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null);
+      }
+      if (saslServer.isComplete()) {
+        String qop = saslServer.getNegotiatedQop();
+        useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
+        ugi =
+          provider.getAuthorizedUgi(saslServer.getAuthorizationID(), this.rpcServer.secretManager);
+        RpcServer.LOG.debug(
+          "SASL server context established. Authenticated client: {}. Negotiated QoP is {}", ugi,
+          qop);
+        this.rpcServer.metrics.authenticationSuccess();
+        RpcServer.AUDITLOG.info(RpcServer.AUTH_SUCCESSFUL_FOR + ugi);
+        saslContextEstablished = true;
+      }
+    }
+  }
   /**
    * Read off the wire. If there is not enough data to read, update the connection state with what
    * we have and returns.
View File
@ -24,6 +24,7 @@ import java.util.Map;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslServer;
+import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES;
 import org.apache.hadoop.hbase.security.provider.AttemptingUserProvidingSaslServer;
 import org.apache.hadoop.hbase.security.provider.SaslServerAuthenticationProvider;
 import org.apache.hadoop.security.token.SecretManager;
@ -40,6 +41,7 @@ public class HBaseSaslRpcServer {

   private final AttemptingUserProvidingSaslServer serverWithProvider;
   private final SaslServer saslServer;
+  private CryptoAES cryptoAES;

   public HBaseSaslRpcServer(SaslServerAuthenticationProvider provider,
     Map<String, String> saslProps, SecretManager<TokenIdentifier> secretManager)
@ -61,16 +63,28 @@ public class HBaseSaslRpcServer {
     SaslUtil.safeDispose(saslServer);
   }

+  public void switchToCryptoAES(CryptoAES cryptoAES) {
+    this.cryptoAES = cryptoAES;
+  }
+
   public String getAttemptingUser() {
     return serverWithProvider.getAttemptingUser().map(Object::toString).orElse("Unknown");
   }

   public byte[] wrap(byte[] buf, int off, int len) throws SaslException {
-    return saslServer.wrap(buf, off, len);
+    if (cryptoAES != null) {
+      return cryptoAES.wrap(buf, off, len);
+    } else {
+      return saslServer.wrap(buf, off, len);
+    }
   }

   public byte[] unwrap(byte[] buf, int off, int len) throws SaslException {
-    return saslServer.unwrap(buf, off, len);
+    if (cryptoAES != null) {
+      return cryptoAES.unwrap(buf, off, len);
+    } else {
+      return saslServer.unwrap(buf, off, len);
+    }
   }

   public String getNegotiatedQop() {
@ -92,4 +106,18 @@ public class HBaseSaslRpcServer {
     }
     return tokenIdentifier;
   }
+
+  /**
+   * Unwrap InvalidToken exception, otherwise return the one passed in.
+   */
+  public static Throwable unwrap(Throwable e) {
+    Throwable cause = e;
+    while (cause != null) {
+      if (cause instanceof InvalidToken) {
+        return cause;
+      }
+      cause = cause.getCause();
+    }
+    return e;
+  }
 }
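A short usage sketch for the new static unwrap(Throwable): it surfaces a SecretManager.InvalidToken buried under layers of SaslException wrapping, so the error reported back to the client names the real cause, exactly as saslReadAndProcess does above. The surrounding catch block is illustrative:

    try {
      replyToken = saslServer.evaluateResponse(tokenBytes);
    } catch (IOException e) {
      // an InvalidToken may be nested several causes deep; dig it out before replying
      Throwable sendToClient = HBaseSaslRpcServer.unwrap(e);
      doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(),
        sendToClient.getLocalizedMessage());
      throw e;
    }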