HBASE-15830 SASL encryption doesn't work with AsyncRpcChannelImpl (Colin Ma)

This commit is contained in:
Gary Helmling 2016-05-26 21:57:28 -07:00
parent b89d88a193
commit da0d74cd27
3 changed files with 82 additions and 23 deletions

View File

@ -231,6 +231,26 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel {
} }
} }
private void startConnectionWithEncryption(Channel ch) {
// for rpc encryption, the order of ChannelInboundHandler should be:
// LengthFieldBasedFrameDecoder->SaslClientHandler->LengthFieldBasedFrameDecoder
// Don't skip the first 4 bytes for length in beforeUnwrapDecoder,
// SaslClientHandler will handler this
ch.pipeline().addFirst("beforeUnwrapDecoder",
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 0));
ch.pipeline().addLast("afterUnwrapDecoder",
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
ch.pipeline().addLast(new AsyncServerResponseHandler(this));
List<AsyncCall> callsToWrite;
synchronized (pendingCalls) {
connected = true;
callsToWrite = new ArrayList<AsyncCall>(pendingCalls.values());
}
for (AsyncCall call : callsToWrite) {
writeRequest(call);
}
}
/** /**
* Get SASL handler * Get SASL handler
* @param bootstrap to reconnect to * @param bootstrap to reconnect to
@ -243,6 +263,7 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel {
client.fallbackAllowed, client.fallbackAllowed,
client.conf.get("hbase.rpc.protection", client.conf.get("hbase.rpc.protection",
SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase()), SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase()),
getChannelHeaderBytes(authMethod),
new SaslClientHandler.SaslExceptionHandler() { new SaslClientHandler.SaslExceptionHandler() {
@Override @Override
public void handle(int retryCount, Random random, Throwable cause) { public void handle(int retryCount, Random random, Throwable cause) {
@ -261,6 +282,11 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel {
public void onSuccess(Channel channel) { public void onSuccess(Channel channel) {
startHBaseConnection(channel); startHBaseConnection(channel);
} }
@Override
public void onSaslProtectionSucess(Channel channel) {
startConnectionWithEncryption(channel);
}
}); });
} }
@ -341,6 +367,25 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel {
* @throws java.io.IOException on failure to write * @throws java.io.IOException on failure to write
*/ */
private ChannelFuture writeChannelHeader(Channel channel) throws IOException { private ChannelFuture writeChannelHeader(Channel channel) throws IOException {
RPCProtos.ConnectionHeader header = getChannelHeader(authMethod);
int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(header);
ByteBuf b = channel.alloc().directBuffer(totalSize);
b.writeInt(header.getSerializedSize());
b.writeBytes(header.toByteArray());
return channel.writeAndFlush(b);
}
private byte[] getChannelHeaderBytes(AuthMethod authMethod) {
RPCProtos.ConnectionHeader header = getChannelHeader(authMethod);
ByteBuffer b = ByteBuffer.allocate(header.getSerializedSize() + 4);
b.putInt(header.getSerializedSize());
b.put(header.toByteArray());
return b.array();
}
private RPCProtos.ConnectionHeader getChannelHeader(AuthMethod authMethod) {
RPCProtos.ConnectionHeader.Builder headerBuilder = RPCProtos.ConnectionHeader.newBuilder() RPCProtos.ConnectionHeader.Builder headerBuilder = RPCProtos.ConnectionHeader.newBuilder()
.setServiceName(serviceName); .setServiceName(serviceName);
@ -357,16 +402,7 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel {
} }
headerBuilder.setVersionInfo(ProtobufUtil.getVersionInfo()); headerBuilder.setVersionInfo(ProtobufUtil.getVersionInfo());
RPCProtos.ConnectionHeader header = headerBuilder.build(); return headerBuilder.build();
int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(header);
ByteBuf b = channel.alloc().directBuffer(totalSize);
b.writeInt(header.getSerializedSize());
b.writeBytes(header.toByteArray());
return channel.writeAndFlush(b);
} }
/** /**

View File

@ -39,6 +39,7 @@ import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException; import javax.security.sasl.SaslException;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.Map; import java.util.Map;
@ -63,6 +64,7 @@ public class SaslClientHandler extends ChannelDuplexHandler {
private final SaslExceptionHandler exceptionHandler; private final SaslExceptionHandler exceptionHandler;
private final SaslSuccessfulConnectHandler successfulConnectHandler; private final SaslSuccessfulConnectHandler successfulConnectHandler;
private byte[] saslToken; private byte[] saslToken;
private byte[] connectionHeader;
private boolean firstRead = true; private boolean firstRead = true;
private int retryCount = 0; private int retryCount = 0;
@ -80,10 +82,11 @@ public class SaslClientHandler extends ChannelDuplexHandler {
*/ */
public SaslClientHandler(UserGroupInformation ticket, AuthMethod method, public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed, Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
String rpcProtection, SaslExceptionHandler exceptionHandler, String rpcProtection, byte[] connectionHeader, SaslExceptionHandler exceptionHandler,
SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException { SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
this.ticket = ticket; this.ticket = ticket;
this.fallbackAllowed = fallbackAllowed; this.fallbackAllowed = fallbackAllowed;
this.connectionHeader = connectionHeader;
this.exceptionHandler = exceptionHandler; this.exceptionHandler = exceptionHandler;
this.successfulConnectHandler = successfulConnectHandler; this.successfulConnectHandler = successfulConnectHandler;
@ -225,8 +228,13 @@ public class SaslClientHandler extends ChannelDuplexHandler {
if (!useWrap) { if (!useWrap) {
ctx.pipeline().remove(this); ctx.pipeline().remove(this);
successfulConnectHandler.onSuccess(ctx.channel());
} else {
byte[] wrappedCH = saslClient.wrap(connectionHeader, 0, connectionHeader.length);
// write connection header
writeSaslToken(ctx, wrappedCH);
successfulConnectHandler.onSaslProtectionSucess(ctx.channel());
} }
successfulConnectHandler.onSuccess(ctx.channel());
} }
} }
// Normal wrapped reading // Normal wrapped reading
@ -303,9 +311,11 @@ public class SaslClientHandler extends ChannelDuplexHandler {
super.write(ctx, msg, promise); super.write(ctx, msg, promise);
} else { } else {
ByteBuf in = (ByteBuf) msg; ByteBuf in = (ByteBuf) msg;
byte[] unwrapped = new byte[in.readableBytes()];
in.readBytes(unwrapped);
try { try {
saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes()); saslToken = saslClient.wrap(unwrapped, 0, unwrapped.length);
} catch (SaslException se) { } catch (SaslException se) {
try { try {
saslClient.dispose(); saslClient.dispose();
@ -355,5 +365,12 @@ public class SaslClientHandler extends ChannelDuplexHandler {
* @param channel which is successfully authenticated * @param channel which is successfully authenticated
*/ */
public void onSuccess(Channel channel); public void onSuccess(Channel channel);
/**
* Runs on success if data protection used in Sasl
*
* @param channel which is successfully authenticated
*/
public void onSaslProtectionSucess(Channel channel);
} }
} }

View File

@ -36,6 +36,7 @@ import java.util.concurrent.ThreadLocalRandom;
import com.google.protobuf.RpcController; import com.google.protobuf.RpcController;
import com.google.protobuf.ServiceException; import com.google.protobuf.ServiceException;
import org.apache.commons.lang.RandomStringUtils;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.hbase.Cell; import org.apache.hadoop.hbase.Cell;
@ -217,6 +218,12 @@ public abstract class AbstractTestSecureIPC {
setRpcProtection("integrity,authentication", "privacy,authentication"); setRpcProtection("integrity,authentication", "privacy,authentication");
callRpcService(User.create(ugi)); callRpcService(User.create(ugi));
setRpcProtection("integrity,authentication", "integrity,authentication");
callRpcService(User.create(ugi));
setRpcProtection("privacy,authentication", "privacy,authentication");
callRpcService(User.create(ugi));
} }
@Test @Test
@ -302,18 +309,17 @@ public abstract class AbstractTestSecureIPC {
@Override @Override
public void run() { public void run() {
String result;
try { try {
result = stub.echo(null, TestProtos.EchoRequestProto.newBuilder().setMessage(String.valueOf( int[] messageSize = new int[] {100, 1000, 10000};
ThreadLocalRandom.current().nextInt())).build()).getMessage(); for (int i = 0; i < messageSize.length; i++) {
} catch (ServiceException e) { String input = RandomStringUtils.random(messageSize[i]);
throw new RuntimeException(e); String result = stub.echo(null, TestProtos.EchoRequestProto.newBuilder()
} .setMessage(input).build()).getMessage();
if (results != null) { assertEquals(input, result);
synchronized (results) {
results.add(result);
}
} }
} catch (ServiceException e) {
throw new RuntimeException(e);
}
} }
} }
} }