HBASE-12684 Add new AsyncRpcClient (Jurriaan Mous)

Signed-off-by: stack <stack@apache.org>
This commit is contained in:
Jurriaan Mous 2015-01-15 13:17:04 +01:00 committed by stack
parent e05341d01d
commit 854f13afa1
18 changed files with 2169 additions and 108 deletions

View File

@ -187,40 +187,32 @@ public abstract class AbstractRpcClient implements RpcClient {
return config.getInt(HConstants.HBASE_CLIENT_IPC_POOL_SIZE, 1); return config.getInt(HConstants.HBASE_CLIENT_IPC_POOL_SIZE, 1);
} }
/** /**
* Make a blocking call. Throws exceptions if there are network problems or if the remote code * Make a blocking call. Throws exceptions if there are network problems or if the remote code
* threw an exception. * threw an exception.
*
* @param ticket Be careful which ticket you pass. A new user will mean a new Connection. * @param ticket Be careful which ticket you pass. A new user will mean a new Connection.
* {@link UserProvider#getCurrent()} makes a new instance of User each time so will be a * {@link UserProvider#getCurrent()} makes a new instance of User each time so
* will be a
* new Connection each time. * new Connection each time.
* @return A pair with the Message response and the Cell data (if any). * @return A pair with the Message response and the Cell data (if any).
*/ */
Message callBlockingMethod(Descriptors.MethodDescriptor md, PayloadCarryingRpcController pcrc, Message callBlockingMethod(Descriptors.MethodDescriptor md, PayloadCarryingRpcController pcrc,
Message param, Message returnType, final User ticket, final InetSocketAddress isa) Message param, Message returnType, final User ticket, final InetSocketAddress isa)
throws ServiceException { throws ServiceException {
if (pcrc == null) {
pcrc = new PayloadCarryingRpcController();
}
long startTime = 0; long startTime = 0;
if (LOG.isTraceEnabled()) { if (LOG.isTraceEnabled()) {
startTime = EnvironmentEdgeManager.currentTime(); startTime = EnvironmentEdgeManager.currentTime();
} }
int callTimeout = 0;
CellScanner cells = null;
if (pcrc != null) {
callTimeout = pcrc.getCallTimeout();
cells = pcrc.cellScanner();
// Clear it here so we don't by mistake try and these cells processing results.
pcrc.setCellScanner(null);
}
Pair<Message, CellScanner> val; Pair<Message, CellScanner> val;
try { try {
val = call(pcrc, md, param, cells, returnType, ticket, isa, callTimeout, val = call(pcrc, md, param, returnType, ticket, isa);
pcrc != null? pcrc.getPriority(): HConstants.NORMAL_QOS);
if (pcrc != null) {
// Shove the results into controller so can be carried across the proxy/pb service void. // Shove the results into controller so can be carried across the proxy/pb service void.
if (val.getSecond() != null) pcrc.setCellScanner(val.getSecond()); pcrc.setCellScanner(val.getSecond());
} else if (val.getSecond() != null) {
throw new ServiceException("Client dropping data on the floor!");
}
if (LOG.isTraceEnabled()) { if (LOG.isTraceEnabled()) {
long callTime = EnvironmentEdgeManager.currentTime() - startTime; long callTime = EnvironmentEdgeManager.currentTime() - startTime;
@ -238,26 +230,22 @@ public abstract class AbstractRpcClient implements RpcClient {
* with the <code>ticket</code> credentials, returning the value. * with the <code>ticket</code> credentials, returning the value.
* Throws exceptions if there are network problems or if the remote code * Throws exceptions if there are network problems or if the remote code
* threw an exception. * threw an exception.
*
* @param ticket Be careful which ticket you pass. A new user will mean a new Connection. * @param ticket Be careful which ticket you pass. A new user will mean a new Connection.
* {@link UserProvider#getCurrent()} makes a new instance of User each time so will be a * {@link UserProvider#getCurrent()} makes a new instance of User each time so
* will be a
* new Connection each time. * new Connection each time.
* @return A pair with the Message response and the Cell data (if any). * @return A pair with the Message response and the Cell data (if any).
* @throws InterruptedException * @throws InterruptedException
* @throws java.io.IOException * @throws java.io.IOException
*/ */
protected abstract Pair<Message, CellScanner> call(PayloadCarryingRpcController pcrc, protected abstract Pair<Message, CellScanner> call(PayloadCarryingRpcController pcrc,
Descriptors.MethodDescriptor md, Message param, CellScanner cells, Descriptors.MethodDescriptor md, Message param, Message returnType, User ticket,
Message returnType, User ticket, InetSocketAddress addr, int callTimeout, int priority) throws InetSocketAddress isa) throws IOException, InterruptedException;
IOException, InterruptedException;
/**
* Creates a "channel" that can be used by a blocking protobuf service. Useful setting up
* protobuf blocking stubs.
* @return A blocking rpc channel that goes via this rpc client instance.
*/
@Override @Override
public BlockingRpcChannel createBlockingRpcChannel(final ServerName sn, public BlockingRpcChannel createBlockingRpcChannel(final ServerName sn, final User ticket,
final User ticket, int defaultOperationTimeout) { int defaultOperationTimeout) {
return new BlockingRpcChannelImplementation(this, sn, ticket, defaultOperationTimeout); return new BlockingRpcChannelImplementation(this, sn, ticket, defaultOperationTimeout);
} }
@ -269,18 +257,17 @@ public abstract class AbstractRpcClient implements RpcClient {
private final InetSocketAddress isa; private final InetSocketAddress isa;
private final AbstractRpcClient rpcClient; private final AbstractRpcClient rpcClient;
private final User ticket; private final User ticket;
private final int defaultOperationTimeout; private final int channelOperationTimeout;
/** /**
* @param defaultOperationTimeout - the default timeout when no timeout is given * @param channelOperationTimeout - the default timeout when no timeout is given
* by the caller.
*/ */
protected BlockingRpcChannelImplementation(final AbstractRpcClient rpcClient, protected BlockingRpcChannelImplementation(final AbstractRpcClient rpcClient,
final ServerName sn, final User ticket, int defaultOperationTimeout) { final ServerName sn, final User ticket, int channelOperationTimeout) {
this.isa = new InetSocketAddress(sn.getHostname(), sn.getPort()); this.isa = new InetSocketAddress(sn.getHostname(), sn.getPort());
this.rpcClient = rpcClient; this.rpcClient = rpcClient;
this.ticket = ticket; this.ticket = ticket;
this.defaultOperationTimeout = defaultOperationTimeout; this.channelOperationTimeout = channelOperationTimeout;
} }
@Override @Override
@ -289,12 +276,12 @@ public abstract class AbstractRpcClient implements RpcClient {
PayloadCarryingRpcController pcrc; PayloadCarryingRpcController pcrc;
if (controller != null) { if (controller != null) {
pcrc = (PayloadCarryingRpcController) controller; pcrc = (PayloadCarryingRpcController) controller;
if (!pcrc.hasCallTimeout()){ if (!pcrc.hasCallTimeout()) {
pcrc.setCallTimeout(defaultOperationTimeout); pcrc.setCallTimeout(channelOperationTimeout);
} }
} else { } else {
pcrc = new PayloadCarryingRpcController(); pcrc = new PayloadCarryingRpcController();
pcrc.setCallTimeout(defaultOperationTimeout); pcrc.setCallTimeout(channelOperationTimeout);
} }
return this.rpcClient.callBlockingMethod(md, pcrc, param, returnType, this.ticket, this.isa); return this.rpcClient.callBlockingMethod(md, pcrc, param, returnType, this.ticket, this.isa);

View File

@ -0,0 +1,135 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.ipc;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import io.netty.channel.EventLoop;
import io.netty.util.concurrent.DefaultPromise;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.CellScanner;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.hbase.protobuf.ProtobufUtil;
import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
import org.apache.hadoop.hbase.util.ExceptionUtil;
import org.apache.hadoop.ipc.RemoteException;
import java.io.IOException;
/**
* Represents an Async Hbase call and its response.
*
* Responses are passed on to its given doneHandler and failures to the rpcController
*/
@InterfaceAudience.Private
public class AsyncCall extends DefaultPromise<Message> {
public static final Log LOG = LogFactory.getLog(AsyncCall.class.getName());
final int id;
final Descriptors.MethodDescriptor method;
final Message param;
final PayloadCarryingRpcController controller;
final Message responseDefaultType;
final long startTime;
final long rpcTimeout;
/**
* Constructor
*
* @param eventLoop for call
* @param connectId connection id
* @param md the method descriptor
* @param param parameters to send to Server
* @param controller controller for response
* @param responseDefaultType the default response type
*/
public AsyncCall(EventLoop eventLoop, int connectId, Descriptors.MethodDescriptor md, Message
param, PayloadCarryingRpcController controller, Message responseDefaultType) {
super(eventLoop);
this.id = connectId;
this.method = md;
this.param = param;
this.controller = controller;
this.responseDefaultType = responseDefaultType;
this.startTime = EnvironmentEdgeManager.currentTime();
this.rpcTimeout = controller.getCallTimeout();
}
/**
* Get the start time
*
* @return start time for the call
*/
public long getStartTime() {
return this.startTime;
}
@Override public String toString() {
return "callId: " + this.id + " methodName: " + this.method.getName() + " param {" +
(this.param != null ? ProtobufUtil.getShortTextFormat(this.param) : "") + "}";
}
/**
* Set success with a cellBlockScanner
*
* @param value to set
* @param cellBlockScanner to set
*/
public void setSuccess(Message value, CellScanner cellBlockScanner) {
if (cellBlockScanner != null) {
controller.setCellScanner(cellBlockScanner);
}
if (LOG.isTraceEnabled()) {
long callTime = EnvironmentEdgeManager.currentTime() - startTime;
LOG.trace("Call: " + method.getName() + ", callTime: " + callTime + "ms");
}
this.setSuccess(value);
}
/**
* Set failed
*
* @param exception to set
*/
public void setFailed(IOException exception) {
if (ExceptionUtil.isInterrupt(exception)) {
exception = ExceptionUtil.asInterrupt(exception);
}
if (exception instanceof RemoteException) {
exception = ((RemoteException) exception).unwrapRemoteException();
}
this.setFailure(exception);
}
/**
* Get the rpc timeout
*
* @return current timeout for this call
*/
public long getRpcTimeout() {
return rpcTimeout;
}
}

View File

@ -0,0 +1,765 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.ipc;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.RpcCallback;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.Timeout;
import io.netty.util.TimerTask;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.hbase.exceptions.ConnectionClosingException;
import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos;
import org.apache.hadoop.hbase.protobuf.generated.RPCProtos;
import org.apache.hadoop.hbase.protobuf.generated.TracingProtos;
import org.apache.hadoop.hbase.security.AuthMethod;
import org.apache.hadoop.hbase.security.SaslClientHandler;
import org.apache.hadoop.hbase.security.SaslUtil;
import org.apache.hadoop.hbase.security.SecurityInfo;
import org.apache.hadoop.hbase.security.User;
import org.apache.hadoop.hbase.security.token.AuthenticationTokenSelector;
import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenSelector;
import org.htrace.Span;
import org.htrace.Trace;
import javax.security.sasl.SaslException;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.TimeUnit;
/**
* Netty RPC channel
*/
@InterfaceAudience.Private
public class AsyncRpcChannel {
public static final Log LOG = LogFactory.getLog(AsyncRpcChannel.class.getName());
private static final int MAX_SASL_RETRIES = 5;
protected final static Map<AuthenticationProtos.TokenIdentifier.Kind, TokenSelector<? extends
TokenIdentifier>> tokenHandlers = new HashMap<>();
static {
tokenHandlers.put(AuthenticationProtos.TokenIdentifier.Kind.HBASE_AUTH_TOKEN,
new AuthenticationTokenSelector());
}
final AsyncRpcClient client;
// Contains the channel to work with.
// Only exists when connected
private Channel channel;
String name;
final User ticket;
final String serviceName;
final InetSocketAddress address;
ConcurrentSkipListMap<Integer, AsyncCall> calls = new ConcurrentSkipListMap<>();
private int ioFailureCounter = 0;
private int connectFailureCounter = 0;
boolean useSasl;
AuthMethod authMethod;
private int reloginMaxBackoff;
private Token<? extends TokenIdentifier> token;
private String serverPrincipal;
boolean shouldCloseConnection = false;
private IOException closeException;
private Timeout cleanupTimer;
private final TimerTask timeoutTask = new TimerTask() {
@Override public void run(Timeout timeout) throws Exception {
cleanupTimer = null;
cleanupCalls(false);
}
};
/**
* Constructor for netty RPC channel
*
* @param bootstrap to construct channel on
* @param client to connect with
* @param ticket of user which uses connection
* @param serviceName name of service to connect to
* @param address to connect to
*/
public AsyncRpcChannel(Bootstrap bootstrap, final AsyncRpcClient client, User ticket, String
serviceName, InetSocketAddress address) {
this.client = client;
this.ticket = ticket;
this.serviceName = serviceName;
this.address = address;
this.channel = connect(bootstrap).channel();
name = ("IPC Client (" + channel.hashCode() + ") connection to " +
address.toString() +
((ticket == null) ?
" from an unknown user" :
(" from " + ticket.getName())));
}
/**
* Connect to channel
*
* @param bootstrap to connect to
* @return future of connection
*/
private ChannelFuture connect(final Bootstrap bootstrap) {
return bootstrap.remoteAddress(address).connect()
.addListener(new GenericFutureListener<ChannelFuture>() {
@Override
public void operationComplete(final ChannelFuture f) throws Exception {
if (!f.isSuccess()) {
if (f.cause() instanceof SocketException) {
retryOrClose(bootstrap, connectFailureCounter++, f.cause());
} else {
retryOrClose(bootstrap, ioFailureCounter++, f.cause());
}
return;
}
channel = f.channel();
setupAuthorization();
ByteBuf b = channel.alloc().directBuffer(6);
createPreamble(b, authMethod);
channel.writeAndFlush(b).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
if (useSasl) {
UserGroupInformation ticket = AsyncRpcChannel.this.ticket.getUGI();
if (authMethod == AuthMethod.KERBEROS) {
if (ticket != null && ticket.getRealUser() != null) {
ticket = ticket.getRealUser();
}
}
SaslClientHandler saslHandler;
if (ticket == null) {
throw new FatalConnectionException("ticket/user is null");
}
saslHandler = ticket.doAs(new PrivilegedExceptionAction<SaslClientHandler>() {
@Override
public SaslClientHandler run() throws IOException {
return getSaslHandler(bootstrap);
}
});
if (saslHandler != null) {
// Sasl connect is successful. Let's set up Sasl channel handler
channel.pipeline().addFirst(saslHandler);
} else {
// fall back to simple auth because server told us so.
authMethod = AuthMethod.SIMPLE;
useSasl = false;
}
} else {
startHBaseConnection(f.channel());
}
}
});
}
/**
* Start HBase connection
*
* @param ch channel to start connection on
*/
private void startHBaseConnection(Channel ch) {
ch.pipeline()
.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
ch.pipeline().addLast(new AsyncServerResponseHandler(this));
try {
writeChannelHeader(ch).addListener(new GenericFutureListener<ChannelFuture>() {
@Override public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
close(future.cause());
return;
}
for (AsyncCall call : calls.values()) {
writeRequest(call);
}
}
});
} catch (IOException e) {
close(e);
}
}
/**
* Get SASL handler
*
* @param bootstrap to reconnect to
* @return new SASL handler
* @throws java.io.IOException if handler failed to create
*/
private SaslClientHandler getSaslHandler(final Bootstrap bootstrap) throws IOException {
return new SaslClientHandler(authMethod, token, serverPrincipal, client.fallbackAllowed,
client.conf.get("hbase.rpc.protection",
SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase()),
new SaslClientHandler.SaslExceptionHandler() {
@Override public void handle(int retryCount, Random random, Throwable cause) {
try {
// Handle Sasl failure. Try to potentially get new credentials
handleSaslConnectionFailure(retryCount, cause, ticket.getUGI());
// Try to reconnect
AsyncRpcClient.WHEEL_TIMER.newTimeout(new TimerTask() {
@Override public void run(Timeout timeout) throws Exception {
connect(bootstrap);
}
}, random.nextInt(reloginMaxBackoff) + 1, TimeUnit.MILLISECONDS);
} catch (IOException | InterruptedException e) {
close(e);
}
}
}, new SaslClientHandler.SaslSuccessfulConnectHandler() {
@Override public void onSuccess(Channel channel) {
startHBaseConnection(channel);
}
});
}
/**
* Retry to connect or close
*
* @param bootstrap to connect with
* @param connectCounter amount of tries
* @param e exception of fail
*/
private void retryOrClose(final Bootstrap bootstrap, int connectCounter, Throwable e) {
if (connectCounter < client.maxRetries) {
AsyncRpcClient.WHEEL_TIMER.newTimeout(new TimerTask() {
@Override public void run(Timeout timeout) throws Exception {
connect(bootstrap);
}
}, client.failureSleep, TimeUnit.MILLISECONDS);
} else {
client.failedServers.addToFailedServers(address);
close(e);
}
}
/**
* Calls method on channel
* @param method to call
* @param controller to run call with
* @param request to send
* @param responsePrototype to construct response with
*/
public Promise<Message> callMethod(final Descriptors.MethodDescriptor method,
final PayloadCarryingRpcController controller, final Message request,
final Message responsePrototype) {
if (shouldCloseConnection) {
Promise<Message> promise = channel.eventLoop().newPromise();
promise.setFailure(new ConnectException());
return promise;
}
final AsyncCall call = new AsyncCall(channel.eventLoop(), client.callIdCnt.getAndIncrement(),
method, request, controller, responsePrototype);
controller.notifyOnCancel(new RpcCallback<Object>() {
@Override
public void run(Object parameter) {
failCall(call, new IOException("Canceled connection"));
}
});
calls.put(call.id, call);
// Add timeout for cleanup if none is present
if (cleanupTimer == null) {
cleanupTimer = AsyncRpcClient.WHEEL_TIMER.newTimeout(timeoutTask, call.getRpcTimeout(),
TimeUnit.MILLISECONDS);
}
if(channel.isActive()) {
writeRequest(call);
}
return call;
}
/**
* Calls method and returns a promise
* @param method to call
* @param controller to run call with
* @param request to send
* @param responsePrototype for response message
* @return Promise to listen to result
* @throws java.net.ConnectException on connection failures
*/
public Promise<Message> callMethodWithPromise(
final Descriptors.MethodDescriptor method, final PayloadCarryingRpcController controller,
final Message request, final Message responsePrototype) throws ConnectException {
if (shouldCloseConnection || !channel.isOpen()) {
throw new ConnectException();
}
return this.callMethod(method, controller, request, responsePrototype);
}
/**
* Write the channel header
*
* @param channel to write to
* @return future of write
* @throws java.io.IOException on failure to write
*/
private ChannelFuture writeChannelHeader(Channel channel) throws IOException {
RPCProtos.ConnectionHeader.Builder headerBuilder =
RPCProtos.ConnectionHeader.newBuilder().setServiceName(serviceName);
RPCProtos.UserInformation userInfoPB = buildUserInfo(ticket.getUGI(), authMethod);
if (userInfoPB != null) {
headerBuilder.setUserInfo(userInfoPB);
}
if (client.codec != null) {
headerBuilder.setCellBlockCodecClass(client.codec.getClass().getCanonicalName());
}
if (client.compressor != null) {
headerBuilder.setCellBlockCompressorClass(client.compressor.getClass().getCanonicalName());
}
RPCProtos.ConnectionHeader header = headerBuilder.build();
int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(header);
ByteBuf b = channel.alloc().directBuffer(totalSize);
b.writeInt(header.getSerializedSize());
b.writeBytes(header.toByteArray());
return channel.writeAndFlush(b);
}
/**
* Write request to channel
*
* @param call to write
*/
private void writeRequest(final AsyncCall call) {
try {
if (shouldCloseConnection) {
return;
}
final RPCProtos.RequestHeader.Builder requestHeaderBuilder = RPCProtos.RequestHeader
.newBuilder();
requestHeaderBuilder.setCallId(call.id)
.setMethodName(call.method.getName()).setRequestParam(call.param != null);
if (Trace.isTracing()) {
Span s = Trace.currentSpan();
requestHeaderBuilder.setTraceInfo(TracingProtos.RPCTInfo.newBuilder().
setParentId(s.getSpanId()).setTraceId(s.getTraceId()));
}
ByteBuffer cellBlock = client.buildCellBlock(call.controller.cellScanner());
if (cellBlock != null) {
final RPCProtos.CellBlockMeta.Builder cellBlockBuilder = RPCProtos.CellBlockMeta
.newBuilder();
cellBlockBuilder.setLength(cellBlock.limit());
requestHeaderBuilder.setCellBlockMeta(cellBlockBuilder.build());
}
// Only pass priority if there one. Let zero be same as no priority.
if (call.controller.getPriority() != 0) {
requestHeaderBuilder.setPriority(call.controller.getPriority());
}
RPCProtos.RequestHeader rh = requestHeaderBuilder.build();
int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(rh, call.param);
if (cellBlock != null) {
totalSize += cellBlock.remaining();
}
ByteBuf b = channel.alloc().directBuffer(4 + totalSize);
try(ByteBufOutputStream out = new ByteBufOutputStream(b)) {
IPCUtil.write(out, rh, call.param, cellBlock);
}
channel.writeAndFlush(b).addListener(new CallWriteListener(this,call));
} catch (IOException e) {
if (!shouldCloseConnection) {
close(e);
}
}
}
/**
* Fail a call
*
* @param call to fail
* @param cause of fail
*/
void failCall(AsyncCall call, IOException cause) {
calls.remove(call.id);
call.setFailed(cause);
}
/**
* Set up server authorization
*
* @throws java.io.IOException if auth setup failed
*/
private void setupAuthorization() throws IOException {
SecurityInfo securityInfo = SecurityInfo.getInfo(serviceName);
this.useSasl = client.userProvider.isHBaseSecurityEnabled();
this.token = null;
if (useSasl && securityInfo != null) {
AuthenticationProtos.TokenIdentifier.Kind tokenKind = securityInfo.getTokenKind();
if (tokenKind != null) {
TokenSelector<? extends TokenIdentifier> tokenSelector = tokenHandlers.get(tokenKind);
if (tokenSelector != null) {
token = tokenSelector
.selectToken(new Text(client.clusterId), ticket.getUGI().getTokens());
} else if (LOG.isDebugEnabled()) {
LOG.debug("No token selector found for type " + tokenKind);
}
}
String serverKey = securityInfo.getServerPrincipal();
if (serverKey == null) {
throw new IOException("Can't obtain server Kerberos config key from SecurityInfo");
}
this.serverPrincipal = SecurityUtil.getServerPrincipal(client.conf.get(serverKey),
address.getAddress().getCanonicalHostName().toLowerCase());
if (LOG.isDebugEnabled()) {
LOG.debug("RPC Server Kerberos principal name for service=" + serviceName + " is "
+ serverPrincipal);
}
}
if (!useSasl) {
authMethod = AuthMethod.SIMPLE;
} else if (token != null) {
authMethod = AuthMethod.DIGEST;
} else {
authMethod = AuthMethod.KERBEROS;
}
if (LOG.isDebugEnabled()) {
LOG.debug("Use " + authMethod + " authentication for service " + serviceName +
", sasl=" + useSasl);
}
reloginMaxBackoff = client.conf.getInt("hbase.security.relogin.maxbackoff", 5000);
}
/**
* Build the user information
*
* @param ugi User Group Information
* @param authMethod Authorization method
* @return UserInformation protobuf
*/
private RPCProtos.UserInformation buildUserInfo(UserGroupInformation ugi, AuthMethod authMethod) {
if (ugi == null || authMethod == AuthMethod.DIGEST) {
// Don't send user for token auth
return null;
}
RPCProtos.UserInformation.Builder userInfoPB = RPCProtos.UserInformation.newBuilder();
if (authMethod == AuthMethod.KERBEROS) {
// Send effective user for Kerberos auth
userInfoPB.setEffectiveUser(ugi.getUserName());
} else if (authMethod == AuthMethod.SIMPLE) {
//Send both effective user and real user for simple auth
userInfoPB.setEffectiveUser(ugi.getUserName());
if (ugi.getRealUser() != null) {
userInfoPB.setRealUser(ugi.getRealUser().getUserName());
}
}
return userInfoPB.build();
}
/**
* Create connection preamble
*
* @param byteBuf to write to
* @param authMethod to write
*/
private void createPreamble(ByteBuf byteBuf, AuthMethod authMethod) {
byteBuf.writeBytes(HConstants.RPC_HEADER);
byteBuf.writeByte(HConstants.RPC_CURRENT_VERSION);
byteBuf.writeByte(authMethod.code);
}
/**
* Close connection
*
* @param e exception on close
*/
public void close(final Throwable e) {
client.removeConnection(ConnectionId.hashCode(ticket,serviceName,address));
// Move closing from the requesting thread to the channel thread
channel.eventLoop().execute(new Runnable() {
@Override
public void run() {
if (shouldCloseConnection) {
return;
}
shouldCloseConnection = true;
if (e != null) {
if (e instanceof IOException) {
closeException = (IOException) e;
} else {
closeException = new IOException(e);
}
}
// log the info
if (LOG.isDebugEnabled() && closeException != null) {
LOG.debug(name + ": closing ipc connection to " + address + ": " +
closeException.getMessage());
}
cleanupCalls(true);
channel.disconnect().addListener(ChannelFutureListener.CLOSE);
if (LOG.isDebugEnabled()) {
LOG.debug(name + ": closed");
}
}
});
}
/**
* Clean up calls.
*
* @param cleanAll true if all calls should be cleaned, false for only the timed out calls
*/
public void cleanupCalls(boolean cleanAll) {
// Cancel outstanding timers
if (cleanupTimer != null) {
cleanupTimer.cancel();
cleanupTimer = null;
}
if (cleanAll) {
for (AsyncCall call : calls.values()) {
synchronized (call) {
// Calls can be done on another thread so check before failing them
if(!call.isDone()) {
if (closeException == null) {
failCall(call, new ConnectionClosingException("Call id=" + call.id +
" on server " + address + " aborted: connection is closing"));
} else {
failCall(call, closeException);
}
}
}
}
} else {
for (AsyncCall call : calls.values()) {
long waitTime = EnvironmentEdgeManager.currentTime() - call.getStartTime();
long timeout = call.getRpcTimeout();
if (timeout > 0 && waitTime >= timeout) {
synchronized (call) {
// Calls can be done on another thread so check before failing them
if (!call.isDone()) {
closeException = new CallTimeoutException("Call id=" + call.id +
", waitTime=" + waitTime + ", rpcTimeout=" + timeout);
failCall(call, closeException);
}
}
} else {
// We expect the call to be ordered by timeout. It may not be the case, but stopping
// at the first valid call allows to be sure that we still have something to do without
// spending too much time by reading the full list.
break;
}
}
if (!calls.isEmpty()) {
AsyncCall firstCall = calls.firstEntry().getValue();
final long newTimeout;
long maxWaitTime = EnvironmentEdgeManager.currentTime() - firstCall.getStartTime();
if (maxWaitTime < firstCall.getRpcTimeout()) {
newTimeout = firstCall.getRpcTimeout() - maxWaitTime;
} else {
newTimeout = 0;
}
closeException = null;
cleanupTimer = AsyncRpcClient.WHEEL_TIMER.newTimeout(timeoutTask,
newTimeout, TimeUnit.MILLISECONDS);
}
}
}
/**
* Check if the connection is alive
*
* @return true if alive
*/
public boolean isAlive() {
return channel.isOpen();
}
/**
* Check if user should authenticate over Kerberos
*
* @return true if should be authenticated over Kerberos
* @throws java.io.IOException on failure of check
*/
private synchronized boolean shouldAuthenticateOverKrb() throws IOException {
UserGroupInformation loginUser = UserGroupInformation.getLoginUser();
UserGroupInformation currentUser = UserGroupInformation.getCurrentUser();
UserGroupInformation realUser = currentUser.getRealUser();
return authMethod == AuthMethod.KERBEROS &&
loginUser != null &&
//Make sure user logged in using Kerberos either keytab or TGT
loginUser.hasKerberosCredentials() &&
// relogin only in case it is the login user (e.g. JT)
// or superuser (like oozie).
(loginUser.equals(currentUser) || loginUser.equals(realUser));
}
/**
* If multiple clients with the same principal try to connect
* to the same server at the same time, the server assumes a
* replay attack is in progress. This is a feature of kerberos.
* In order to work around this, what is done is that the client
* backs off randomly and tries to initiate the connection
* again.
* The other problem is to do with ticket expiry. To handle that,
* a relogin is attempted.
* <p>
* The retry logic is governed by the {@link #shouldAuthenticateOverKrb}
* method. In case when the user doesn't have valid credentials, we don't
* need to retry (from cache or ticket). In such cases, it is prudent to
* throw a runtime exception when we receive a SaslException from the
* underlying authentication implementation, so there is no retry from
* other high level (for eg, HCM or HBaseAdmin).
* </p>
*
* @param currRetries retry count
* @param ex exception describing fail
* @param user which is trying to connect
* @throws java.io.IOException if IO fail
* @throws InterruptedException if thread is interrupted
*/
private void handleSaslConnectionFailure(final int currRetries, final Throwable ex,
final UserGroupInformation user) throws IOException, InterruptedException {
user.doAs(new PrivilegedExceptionAction<Void>() {
public Void run() throws IOException, InterruptedException {
if (shouldAuthenticateOverKrb()) {
if (currRetries < MAX_SASL_RETRIES) {
LOG.debug("Exception encountered while connecting to the server : " + ex);
//try re-login
if (UserGroupInformation.isLoginKeytabBased()) {
UserGroupInformation.getLoginUser().reloginFromKeytab();
} else {
UserGroupInformation.getLoginUser().reloginFromTicketCache();
}
// Should reconnect
return null;
} else {
String msg = "Couldn't setup connection for " +
UserGroupInformation.getLoginUser().getUserName() +
" to " + serverPrincipal;
LOG.warn(msg);
throw (IOException) new IOException(msg).initCause(ex);
}
} else {
LOG.warn("Exception encountered while connecting to " +
"the server : " + ex);
}
if (ex instanceof RemoteException) {
throw (RemoteException) ex;
}
if (ex instanceof SaslException) {
String msg = "SASL authentication failed." +
" The most likely cause is missing or invalid credentials." +
" Consider 'kinit'.";
LOG.fatal(msg, ex);
throw new RuntimeException(msg, ex);
}
throw new IOException(ex);
}
});
}
@Override
public String toString() {
return this.address.toString() + "/" + this.serviceName + "/" + this.ticket;
}
/**
* Listens to call writes and fails if write failed
*/
private static final class CallWriteListener implements ChannelFutureListener {
private final AsyncRpcChannel rpcChannel;
private final AsyncCall call;
public CallWriteListener(AsyncRpcChannel asyncRpcChannel, AsyncCall call) {
this.rpcChannel = asyncRpcChannel;
this.call = call;
}
@Override public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
if(!this.call.isDone()) {
if (future.cause() instanceof IOException) {
rpcChannel.failCall(call, (IOException) future.cause());
} else {
rpcChannel.failCall(call, new IOException(future.cause()));
}
}
}
}
}
}

View File

@ -0,0 +1,402 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.ipc;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.RpcCallback;
import com.google.protobuf.RpcChannel;
import com.google.protobuf.RpcController;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.HashedWheelTimer;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.CellScanner;
import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.ServerName;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.hbase.security.User;
import org.apache.hadoop.hbase.util.JVM;
import org.apache.hadoop.hbase.util.Pair;
import org.apache.hadoop.hbase.util.PoolMap;
import org.apache.hadoop.hbase.util.Threads;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Netty client for the requests and responses
*/
@InterfaceAudience.Private
public class AsyncRpcClient extends AbstractRpcClient {
public static final String CLIENT_MAX_THREADS = "hbase.rpc.client.threads.max";
public static final String USE_NATIVE_TRANSPORT = "hbase.rpc.client.useNativeTransport";
public static final HashedWheelTimer WHEEL_TIMER =
new HashedWheelTimer(100, TimeUnit.MILLISECONDS);
private static final ChannelInitializer<SocketChannel> DEFAULT_CHANNEL_INITIALIZER =
new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
//empty initializer
}
};
protected final AtomicInteger callIdCnt = new AtomicInteger();
private final EventLoopGroup eventLoopGroup;
private final PoolMap<Integer, AsyncRpcChannel> connections;
final FailedServers failedServers;
private final Bootstrap bootstrap;
/**
* Constructor for tests
*
* @param configuration to HBase
* @param clusterId for the cluster
* @param localAddress local address to connect to
* @param channelInitializer for custom channel handlers
*/
@VisibleForTesting
AsyncRpcClient(Configuration configuration, String clusterId, SocketAddress localAddress,
ChannelInitializer<SocketChannel> channelInitializer) {
super(configuration, clusterId, localAddress);
if (LOG.isDebugEnabled()) {
LOG.debug("Starting async Hbase RPC client");
}
// Max amount of threads to use. 0 lets Netty decide based on amount of cores
int maxThreads = conf.getInt(CLIENT_MAX_THREADS, 0);
// Config to enable native transport. Does not seem to be stable at time of implementation
// although it is not extensively tested.
boolean epollEnabled = conf.getBoolean(USE_NATIVE_TRANSPORT, false);
// Use the faster native epoll transport mechanism on linux if enabled
Class<? extends Channel> socketChannelClass;
if (epollEnabled && JVM.isLinux()) {
socketChannelClass = EpollSocketChannel.class;
this.eventLoopGroup =
new EpollEventLoopGroup(maxThreads, Threads.newDaemonThreadFactory("AsyncRpcChannel"));
} else {
socketChannelClass = NioSocketChannel.class;
this.eventLoopGroup =
new NioEventLoopGroup(maxThreads, Threads.newDaemonThreadFactory("AsyncRpcChannel"));
}
this.connections = new PoolMap<>(getPoolType(configuration), getPoolSize(configuration));
this.failedServers = new FailedServers(configuration);
int operationTimeout = configuration.getInt(HConstants.HBASE_CLIENT_OPERATION_TIMEOUT,
HConstants.DEFAULT_HBASE_CLIENT_OPERATION_TIMEOUT);
// Configure the default bootstrap.
this.bootstrap = new Bootstrap();
bootstrap.group(eventLoopGroup).channel(socketChannelClass)
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
.option(ChannelOption.TCP_NODELAY, tcpNoDelay)
.option(ChannelOption.SO_KEEPALIVE, tcpKeepAlive)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, operationTimeout);
if (channelInitializer == null) {
channelInitializer = DEFAULT_CHANNEL_INITIALIZER;
}
bootstrap.handler(channelInitializer);
if (localAddress != null) {
bootstrap.localAddress(localAddress);
}
}
/**
* Constructor
*
* @param configuration to HBase
* @param clusterId for the cluster
* @param localAddress local address to connect to
*/
public AsyncRpcClient(Configuration configuration, String clusterId, SocketAddress localAddress) {
this(configuration, clusterId, localAddress, null);
}
/**
* Make a call, passing <code>param</code>, to the IPC server running at
* <code>address</code> which is servicing the <code>protocol</code> protocol,
* with the <code>ticket</code> credentials, returning the value.
* Throws exceptions if there are network problems or if the remote code
* threw an exception.
*
* @param ticket Be careful which ticket you pass. A new user will mean a new Connection.
* {@link org.apache.hadoop.hbase.security.UserProvider#getCurrent()} makes a new
* instance of User each time so will be a new Connection each time.
* @return A pair with the Message response and the Cell data (if any).
* @throws InterruptedException if call is interrupted
* @throws java.io.IOException if a connection failure is encountered
*/
@Override protected Pair<Message, CellScanner> call(PayloadCarryingRpcController pcrc,
Descriptors.MethodDescriptor md, Message param, Message returnType, User ticket,
InetSocketAddress addr) throws IOException, InterruptedException {
final AsyncRpcChannel connection = createRpcChannel(md.getService().getName(), addr, ticket);
Promise<Message> promise = connection.callMethodWithPromise(md, pcrc, param, returnType);
try {
Message response = promise.get();
return new Pair<>(response, pcrc.cellScanner());
} catch (ExecutionException e) {
if (e.getCause() instanceof IOException) {
throw (IOException) e.getCause();
} else {
throw new IOException(e.getCause());
}
}
}
/**
* Call method async
*/
private void callMethod(Descriptors.MethodDescriptor md, final PayloadCarryingRpcController pcrc,
Message param, Message returnType, User ticket, InetSocketAddress addr,
final RpcCallback<Message> done) {
final AsyncRpcChannel connection;
try {
connection = createRpcChannel(md.getService().getName(), addr, ticket);
connection.callMethod(md, pcrc, param, returnType).addListener(
new GenericFutureListener<Future<Message>>() {
@Override
public void operationComplete(Future<Message> future) throws Exception {
if(!future.isSuccess()){
Throwable cause = future.cause();
if (cause instanceof IOException) {
pcrc.setFailed((IOException) cause);
}else{
pcrc.setFailed(new IOException(cause));
}
}else{
try {
done.run(future.get());
}catch (ExecutionException e){
Throwable cause = e.getCause();
if (cause instanceof IOException) {
pcrc.setFailed((IOException) cause);
}else{
pcrc.setFailed(new IOException(cause));
}
}catch (InterruptedException e){
pcrc.setFailed(new IOException(e));
}
}
}
});
} catch (StoppedRpcClientException|FailedServerException e) {
pcrc.setFailed(e);
}
}
/**
* Close netty
*/
public void close() {
if (LOG.isDebugEnabled()) {
LOG.debug("Stopping async HBase RPC client");
}
synchronized (connections) {
for (AsyncRpcChannel conn : connections.values()) {
conn.close(null);
}
}
eventLoopGroup.shutdownGracefully();
}
/**
* Create a cell scanner
*
* @param cellBlock to create scanner for
* @return CellScanner
* @throws java.io.IOException on error on creation cell scanner
*/
public CellScanner createCellScanner(byte[] cellBlock) throws IOException {
return ipcUtil.createCellScanner(this.codec, this.compressor, cellBlock);
}
/**
* Build cell block
*
* @param cells to create block with
* @return ByteBuffer with cells
* @throws java.io.IOException if block creation fails
*/
public ByteBuffer buildCellBlock(CellScanner cells) throws IOException {
return ipcUtil.buildCellBlock(this.codec, this.compressor, cells);
}
/**
* Creates an RPC client
*
* @param serviceName name of servicce
* @param location to connect to
* @param ticket for current user
* @return new RpcChannel
* @throws StoppedRpcClientException when Rpc client is stopped
* @throws FailedServerException if server failed
*/
private AsyncRpcChannel createRpcChannel(String serviceName, InetSocketAddress location,
User ticket) throws StoppedRpcClientException, FailedServerException {
if (this.eventLoopGroup.isShuttingDown() || this.eventLoopGroup.isShutdown()) {
throw new StoppedRpcClientException();
}
// Check if server is failed
if (this.failedServers.isFailedServer(location)) {
if (LOG.isDebugEnabled()) {
LOG.debug("Not trying to connect to " + location +
" this server is in the failed servers list");
}
throw new FailedServerException(
"This server is in the failed servers list: " + location);
}
int hashCode = ConnectionId.hashCode(ticket,serviceName,location);
AsyncRpcChannel rpcChannel;
synchronized (connections) {
rpcChannel = connections.get(hashCode);
if (rpcChannel == null) {
rpcChannel = new AsyncRpcChannel(this.bootstrap, this, ticket, serviceName, location);
connections.put(hashCode, rpcChannel);
}
}
return rpcChannel;
}
/**
* Interrupt the connections to the given ip:port server. This should be called if the server
* is known as actually dead. This will not prevent current operation to be retried, and,
* depending on their own behavior, they may retry on the same server. This can be a feature,
* for example at startup. In any case, they're likely to get connection refused (if the
* process died) or no route to host: i.e. there next retries should be faster and with a
* safe exception.
*
* @param sn server to cancel connections for
*/
@Override
public void cancelConnections(ServerName sn) {
synchronized (connections) {
for (AsyncRpcChannel rpcChannel : connections.values()) {
if (rpcChannel.isAlive() &&
rpcChannel.address.getPort() == sn.getPort() &&
rpcChannel.address.getHostName().contentEquals(sn.getHostname())) {
LOG.info("The server on " + sn.toString() +
" is dead - stopping the connection " + rpcChannel.toString());
rpcChannel.close(null);
}
}
}
}
/**
* Remove connection from pool
*
* @param connectionHashCode of connection
*/
public void removeConnection(int connectionHashCode) {
synchronized (connections) {
this.connections.remove(connectionHashCode);
}
}
/**
* Creates a "channel" that can be used by a protobuf service. Useful setting up
* protobuf stubs.
*
* @param sn server name describing location of server
* @param user which is to use the connection
* @param rpcTimeout default rpc operation timeout
*
* @return A rpc channel that goes via this rpc client instance.
* @throws IOException when channel could not be created
*/
public RpcChannel createRpcChannel(final ServerName sn, final User user, int rpcTimeout) {
return new RpcChannelImplementation(this, sn, user, rpcTimeout);
}
/**
* Blocking rpc channel that goes via hbase rpc.
*/
@VisibleForTesting
public static class RpcChannelImplementation implements RpcChannel {
private final InetSocketAddress isa;
private final AsyncRpcClient rpcClient;
private final User ticket;
private final int channelOperationTimeout;
/**
* @param channelOperationTimeout - the default timeout when no timeout is given
*/
protected RpcChannelImplementation(final AsyncRpcClient rpcClient,
final ServerName sn, final User ticket, int channelOperationTimeout) {
this.isa = new InetSocketAddress(sn.getHostname(), sn.getPort());
this.rpcClient = rpcClient;
this.ticket = ticket;
this.channelOperationTimeout = channelOperationTimeout;
}
@Override
public void callMethod(Descriptors.MethodDescriptor md, RpcController controller,
Message param, Message returnType, RpcCallback<Message> done) {
PayloadCarryingRpcController pcrc;
if (controller != null) {
pcrc = (PayloadCarryingRpcController) controller;
if (!pcrc.hasCallTimeout()) {
pcrc.setCallTimeout(channelOperationTimeout);
}
} else {
pcrc = new PayloadCarryingRpcController();
pcrc.setCallTimeout(channelOperationTimeout);
}
this.rpcClient.callMethod(md, pcrc, param, returnType, this.ticket, this.isa, done);
}
}
}

View File

@ -0,0 +1,130 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.ipc;
import com.google.protobuf.Message;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.CellScanner;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.hbase.protobuf.generated.RPCProtos;
import org.apache.hadoop.ipc.RemoteException;
import java.io.IOException;
/**
* Handles Hbase responses
*/
@InterfaceAudience.Private
public class AsyncServerResponseHandler extends ChannelInboundHandlerAdapter {
public static final Log LOG = LogFactory.getLog(AsyncServerResponseHandler.class.getName());
private final AsyncRpcChannel channel;
/**
* Constructor
*
* @param channel on which this response handler operates
*/
public AsyncServerResponseHandler(AsyncRpcChannel channel) {
this.channel = channel;
}
@Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ByteBuf inBuffer = (ByteBuf) msg;
ByteBufInputStream in = new ByteBufInputStream(inBuffer);
if (channel.shouldCloseConnection) {
return;
}
int totalSize = inBuffer.readableBytes();
try {
// Read the header
RPCProtos.ResponseHeader responseHeader = RPCProtos.ResponseHeader.parseDelimitedFrom(in);
int id = responseHeader.getCallId();
AsyncCall call = channel.calls.get(id);
if (call == null) {
// So we got a response for which we have no corresponding 'call' here on the client-side.
// We probably timed out waiting, cleaned up all references, and now the server decides
// to return a response. There is nothing we can do w/ the response at this stage. Clean
// out the wire of the response so its out of the way and we can get other responses on
// this connection.
int readSoFar = IPCUtil.getTotalSizeWhenWrittenDelimited(responseHeader);
int whatIsLeftToRead = totalSize - readSoFar;
// This is done through a Netty ByteBuf which has different behavior than InputStream.
// It does not return number of bytes read but will update pointer internally and throws an
// exception when too many bytes are to be skipped.
inBuffer.skipBytes(whatIsLeftToRead);
return;
}
if (responseHeader.hasException()) {
RPCProtos.ExceptionResponse exceptionResponse = responseHeader.getException();
RemoteException re = createRemoteException(exceptionResponse);
if (exceptionResponse.getExceptionClassName().
equals(FatalConnectionException.class.getName())) {
channel.close(re);
} else {
channel.failCall(call, re);
}
} else {
Message value = null;
// Call may be null because it may have timedout and been cleaned up on this side already
if (call.responseDefaultType != null) {
Message.Builder builder = call.responseDefaultType.newBuilderForType();
builder.mergeDelimitedFrom(in);
value = builder.build();
}
CellScanner cellBlockScanner = null;
if (responseHeader.hasCellBlockMeta()) {
int size = responseHeader.getCellBlockMeta().getLength();
byte[] cellBlock = new byte[size];
inBuffer.readBytes(cellBlock, 0, cellBlock.length);
cellBlockScanner = channel.client.createCellScanner(cellBlock);
}
call.setSuccess(value, cellBlockScanner);
}
channel.calls.remove(id);
} catch (IOException e) {
// Treat this as a fatal condition and close this connection
channel.close(e);
} finally {
inBuffer.release();
channel.cleanupCalls(false);
}
}
/**
* @param e Proto exception
* @return RemoteException made from passed <code>e</code>
*/
private RemoteException createRemoteException(final RPCProtos.ExceptionResponse e) {
String innerExceptionClassName = e.getExceptionClassName();
boolean doNotRetry = e.getDoNotRetry();
return e.hasHostname() ?
// If a hostname then add it to the RemoteWithExtrasException
new RemoteWithExtrasException(innerExceptionClassName, e.getStackTrace(), e.getHostname(),
e.getPort(), doNotRetry) :
new RemoteWithExtrasException(innerExceptionClassName, e.getStackTrace(), doNotRetry);
}
}

View File

@ -28,10 +28,10 @@ import java.net.InetSocketAddress;
*/ */
@InterfaceAudience.Private @InterfaceAudience.Private
public class ConnectionId { public class ConnectionId {
final InetSocketAddress address;
final User ticket;
private static final int PRIME = 16777619; private static final int PRIME = 16777619;
final User ticket;
final String serviceName; final String serviceName;
final InetSocketAddress address;
public ConnectionId(User ticket, String serviceName, InetSocketAddress address) { public ConnectionId(User ticket, String serviceName, InetSocketAddress address) {
this.address = address; this.address = address;
@ -70,9 +70,12 @@ public class ConnectionId {
@Override // simply use the default Object#hashcode() ? @Override // simply use the default Object#hashcode() ?
public int hashCode() { public int hashCode() {
int hashcode = (address.hashCode() + return hashCode(ticket,serviceName,address);
PRIME * (PRIME * this.serviceName.hashCode() ^ }
public static int hashCode(User ticket, String serviceName, InetSocketAddress address){
return (address.hashCode() +
PRIME * (PRIME * serviceName.hashCode() ^
(ticket == null ? 0 : ticket.hashCode()))); (ticket == null ? 0 : ticket.hashCode())));
return hashcode;
} }
} }

View File

@ -42,7 +42,7 @@ public class PayloadCarryingRpcController
*/ */
// Currently only multi call makes use of this. Eventually this should be only way to set // Currently only multi call makes use of this. Eventually this should be only way to set
// priority. // priority.
private int priority = 0; private int priority = HConstants.NORMAL_QOS;
/** /**
* They are optionally set on construction, cleared after we make the call, and then optionally * They are optionally set on construction, cleared after we make the call, and then optionally

View File

@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.hbase.security.User; import org.apache.hadoop.hbase.security.User;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException;
/** /**
* Interface for RpcClient implementations so ConnectionManager can handle it. * Interface for RpcClient implementations so ConnectionManager can handle it.
@ -56,9 +57,15 @@ import java.io.Closeable;
* Creates a "channel" that can be used by a blocking protobuf service. Useful setting up * Creates a "channel" that can be used by a blocking protobuf service. Useful setting up
* protobuf blocking stubs. * protobuf blocking stubs.
* *
* @param sn server name describing location of server
* @param user which is to use the connection
* @param rpcTimeout default rpc operation timeout
*
* @return A blocking rpc channel that goes via this rpc client instance. * @return A blocking rpc channel that goes via this rpc client instance.
* @throws IOException when channel could not be created
*/ */
public BlockingRpcChannel createBlockingRpcChannel(ServerName sn, User user, int rpcTimeout); public BlockingRpcChannel createBlockingRpcChannel(ServerName sn, User user,
int rpcTimeout) throws IOException;
/** /**
* Interrupt the connections to the given server. This should be called if the server * Interrupt the connections to the given server. This should be called if the server
@ -67,6 +74,7 @@ import java.io.Closeable;
* for example at startup. In any case, they're likely to get connection refused (if the * for example at startup. In any case, they're likely to get connection refused (if the
* process died) or no route to host: i.e. their next retries should be faster and with a * process died) or no route to host: i.e. their next retries should be faster and with a
* safe exception. * safe exception.
* @param sn server location to cancel connections of
*/ */
public void cancelConnections(ServerName sn); public void cancelConnections(ServerName sn);

View File

@ -59,8 +59,7 @@ public final class RpcClientFactory {
public static RpcClient createClient(Configuration conf, String clusterId, public static RpcClient createClient(Configuration conf, String clusterId,
SocketAddress localAddr) { SocketAddress localAddr) {
String rpcClientClass = String rpcClientClass =
conf.get(CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, conf.get(CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, AsyncRpcClient.class.getName());
RpcClientImpl.class.getName());
return ReflectionUtils.instantiateWithCustomCtor( return ReflectionUtils.instantiateWithCustomCtor(
rpcClientClass, rpcClientClass,
new Class[] { Configuration.class, String.class, SocketAddress.class }, new Class[] { Configuration.class, String.class, SocketAddress.class },

View File

@ -787,9 +787,9 @@ public class RpcClientImpl extends AbstractRpcClient {
// up the reading on occasion (the passed in stream is not buffered yet). // up the reading on occasion (the passed in stream is not buffered yet).
// Preamble is six bytes -- 'HBas' + VERSION + AUTH_CODE // Preamble is six bytes -- 'HBas' + VERSION + AUTH_CODE
int rpcHeaderLen = HConstants.RPC_HEADER.array().length; int rpcHeaderLen = HConstants.RPC_HEADER.length;
byte [] preamble = new byte [rpcHeaderLen + 2]; byte [] preamble = new byte [rpcHeaderLen + 2];
System.arraycopy(HConstants.RPC_HEADER.array(), 0, preamble, 0, rpcHeaderLen); System.arraycopy(HConstants.RPC_HEADER, 0, preamble, 0, rpcHeaderLen);
preamble[rpcHeaderLen] = HConstants.RPC_CURRENT_VERSION; preamble[rpcHeaderLen] = HConstants.RPC_CURRENT_VERSION;
preamble[rpcHeaderLen + 1] = authMethod.code; preamble[rpcHeaderLen + 1] = authMethod.code;
outStream.write(preamble); outStream.write(preamble);
@ -1120,14 +1120,6 @@ public class RpcClientImpl extends AbstractRpcClient {
} }
} }
Pair<Message, CellScanner> call(PayloadCarryingRpcController pcrc,
MethodDescriptor md, Message param, CellScanner cells,
Message returnType, User ticket, InetSocketAddress addr, int rpcTimeout)
throws InterruptedException, IOException {
return
call(pcrc, md, param, cells, returnType, ticket, addr, rpcTimeout, HConstants.NORMAL_QOS);
}
/** Make a call, passing <code>param</code>, to the IPC server running at /** Make a call, passing <code>param</code>, to the IPC server running at
* <code>address</code> which is servicing the <code>protocol</code> protocol, * <code>address</code> which is servicing the <code>protocol</code> protocol,
* with the <code>ticket</code> credentials, returning the value. * with the <code>ticket</code> credentials, returning the value.
@ -1140,21 +1132,22 @@ public class RpcClientImpl extends AbstractRpcClient {
* @throws InterruptedException * @throws InterruptedException
* @throws IOException * @throws IOException
*/ */
@Override
protected Pair<Message, CellScanner> call(PayloadCarryingRpcController pcrc, MethodDescriptor md, protected Pair<Message, CellScanner> call(PayloadCarryingRpcController pcrc, MethodDescriptor md,
Message param, CellScanner cells, Message param, Message returnType, User ticket, InetSocketAddress addr)
Message returnType, User ticket, InetSocketAddress addr, int callTimeout, int priority)
throws IOException, InterruptedException { throws IOException, InterruptedException {
final Call call = new Call( if (pcrc == null) {
this.callIdCnt.getAndIncrement(), pcrc = new PayloadCarryingRpcController();
md, param, cells, returnType, callTimeout); }
CellScanner cells = pcrc.cellScanner();
final Connection connection = getConnection(ticket, call, addr, this.codec, this.compressor); final Call call = new Call(this.callIdCnt.getAndIncrement(), md, param, cells, returnType,
pcrc.getCallTimeout());
final Connection connection = getConnection(ticket, call, addr);
final CallFuture cts; final CallFuture cts;
if (connection.callSender != null) { if (connection.callSender != null) {
cts = connection.callSender.sendCall(call, priority, Trace.currentSpan()); cts = connection.callSender.sendCall(call, pcrc.getPriority(), Trace.currentSpan());
if (pcrc != null) {
pcrc.notifyOnCancel(new RpcCallback<Object>() { pcrc.notifyOnCancel(new RpcCallback<Object>() {
@Override @Override
public void run(Object parameter) { public void run(Object parameter) {
@ -1166,11 +1159,9 @@ public class RpcClientImpl extends AbstractRpcClient {
call.callComplete(); call.callComplete();
return new Pair<Message, CellScanner>(call.response, call.cells); return new Pair<Message, CellScanner>(call.response, call.cells);
} }
}
} else { } else {
cts = null; cts = null;
connection.tracedWriteRequest(call, priority, Trace.currentSpan()); connection.tracedWriteRequest(call, pcrc.getPriority(), Trace.currentSpan());
} }
while (!call.done) { while (!call.done) {
@ -1265,8 +1256,7 @@ public class RpcClientImpl extends AbstractRpcClient {
* Get a connection from the pool, or create a new one and add it to the * Get a connection from the pool, or create a new one and add it to the
* pool. Connections to a given host/port are reused. * pool. Connections to a given host/port are reused.
*/ */
protected Connection getConnection(User ticket, Call call, InetSocketAddress addr, protected Connection getConnection(User ticket, Call call, InetSocketAddress addr)
final Codec codec, final CompressionCodec compressor)
throws IOException { throws IOException {
if (!running.get()) throw new StoppedRpcClientException(); if (!running.get()) throw new StoppedRpcClientException();
Connection connection; Connection connection;

View File

@ -42,8 +42,12 @@ public class TimeLimitedRpcController implements RpcController {
private IOException exception; private IOException exception;
public Integer getCallTimeout() { public int getCallTimeout() {
if (callTimeout != null) {
return callTimeout; return callTimeout;
} else {
return 0;
}
} }
public void setCallTimeout(int callTimeout) { public void setCallTimeout(int callTimeout) {

View File

@ -117,7 +117,7 @@ public class HBaseSaslRpcClient {
throw new IOException( throw new IOException(
"Failed to specify server's Kerberos principal name"); "Failed to specify server's Kerberos principal name");
} }
String names[] = SaslUtil.splitKerberosName(serverPrincipal); String[] names = SaslUtil.splitKerberosName(serverPrincipal);
if (names.length != 3) { if (names.length != 3) {
throw new IOException( throw new IOException(
"Kerberos principal does not have the expected format: " "Kerberos principal does not have the expected format: "

View File

@ -0,0 +1,353 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.security;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Random;
/**
* Handles Sasl connections
*/
@InterfaceAudience.Private
public class SaslClientHandler extends ChannelDuplexHandler {
public static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
private final boolean fallbackAllowed;
/**
* Used for client or server's token to send or receive from each other.
*/
private final SaslClient saslClient;
private final SaslExceptionHandler exceptionHandler;
private final SaslSuccessfulConnectHandler successfulConnectHandler;
private byte[] saslToken;
private boolean firstRead = true;
private int retryCount = 0;
private Random random;
/**
* Constructor
*
* @param method auth method
* @param token for Sasl
* @param serverPrincipal Server's Kerberos principal name
* @param fallbackAllowed True if server may also fall back to less secure connection
* @param rpcProtection Quality of protection. Integrity or privacy
* @param exceptionHandler handler for exceptions
* @param successfulConnectHandler handler for succesful connects
* @throws java.io.IOException if handler could not be created
*/
public SaslClientHandler(AuthMethod method, Token<? extends TokenIdentifier> token,
String serverPrincipal, boolean fallbackAllowed, String rpcProtection,
SaslExceptionHandler exceptionHandler, SaslSuccessfulConnectHandler successfulConnectHandler)
throws IOException {
this.fallbackAllowed = fallbackAllowed;
this.exceptionHandler = exceptionHandler;
this.successfulConnectHandler = successfulConnectHandler;
SaslUtil.initSaslProperties(rpcProtection);
switch (method) {
case DIGEST:
if (LOG.isDebugEnabled())
LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
+ " client to authenticate to service at " + token.getService());
saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
break;
case KERBEROS:
if (LOG.isDebugEnabled()) {
LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
+ " client. Server's Kerberos principal name is " + serverPrincipal);
}
if (serverPrincipal == null || serverPrincipal.isEmpty()) {
throw new IOException("Failed to specify server's Kerberos principal name");
}
String[] names = SaslUtil.splitKerberosName(serverPrincipal);
if (names.length != 3) {
throw new IOException(
"Kerberos principal does not have the expected format: " + serverPrincipal);
}
saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
names[0], names[1]);
break;
default:
throw new IOException("Unknown authentication method " + method);
}
if (saslClient == null)
throw new IOException("Unable to find SASL client implementation");
}
/**
* Create a Digest Sasl client
*
* @param mechanismNames names of mechanisms
* @param saslDefaultRealm default realm for sasl
* @param saslClientCallbackHandler handler for the client
* @return new SaslClient
* @throws java.io.IOException if creation went wrong
*/
protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
CallbackHandler saslClientCallbackHandler) throws IOException {
return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
saslClientCallbackHandler);
}
/**
* Create Kerberos client
*
* @param mechanismNames names of mechanisms
* @param userFirstPart first part of username
* @param userSecondPart second part of username
* @return new SaslClient
* @throws java.io.IOException if fails
*/
protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
String userSecondPart) throws IOException {
return Sasl
.createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
null);
}
@Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
saslClient.dispose();
}
@Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.saslToken = new byte[0];
if (saslClient.hasInitialResponse()) {
saslToken = saslClient.evaluateChallenge(saslToken);
}
if (saslToken != null) {
writeSaslToken(ctx, saslToken);
if (LOG.isDebugEnabled()) {
LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
}
}
}
@Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ByteBuf in = (ByteBuf) msg;
// If not complete, try to negotiate
if (!saslClient.isComplete()) {
while (!saslClient.isComplete() && in.isReadable()) {
readStatus(in);
int len = in.readInt();
if (firstRead) {
firstRead = false;
if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
if (!fallbackAllowed) {
throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
+ "client is configured to only allow secure connections.");
}
if (LOG.isDebugEnabled()) {
LOG.debug("Server asks us to fall back to simple auth.");
}
saslClient.dispose();
ctx.pipeline().remove(this);
successfulConnectHandler.onSuccess(ctx.channel());
return;
}
}
saslToken = new byte[len];
if (LOG.isDebugEnabled())
LOG.debug("Will read input token of size " + saslToken.length
+ " for processing by initSASLContext");
in.readBytes(saslToken);
saslToken = saslClient.evaluateChallenge(saslToken);
if (saslToken != null) {
if (LOG.isDebugEnabled())
LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
writeSaslToken(ctx, saslToken);
}
}
if (saslClient.isComplete()) {
String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
if (LOG.isDebugEnabled()) {
LOG.debug("SASL client context established. Negotiated QoP: " + qop);
}
boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
if (!useWrap) {
ctx.pipeline().remove(this);
}
successfulConnectHandler.onSuccess(ctx.channel());
}
}
// Normal wrapped reading
else {
try {
int length = in.readInt();
if (LOG.isDebugEnabled()) {
LOG.debug("Actual length is " + length);
}
saslToken = new byte[length];
in.readBytes(saslToken);
} catch (IndexOutOfBoundsException e) {
return;
}
try {
ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
ctx.fireChannelRead(b);
} catch (SaslException se) {
try {
saslClient.dispose();
} catch (SaslException ignored) {
LOG.debug("Ignoring SASL exception", ignored);
}
throw se;
}
}
}
/**
* Write SASL token
*
* @param ctx to write to
* @param saslToken to write
*/
private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
b.writeInt(saslToken.length);
b.writeBytes(saslToken, 0, saslToken.length);
ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
@Override public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
exceptionCaught(ctx, future.cause());
}
}
});
}
/**
* Get the read status
*
* @param inStream to read
* @throws org.apache.hadoop.ipc.RemoteException if status was not success
*/
private static void readStatus(ByteBuf inStream) throws RemoteException {
int status = inStream.readInt(); // read status
if (status != SaslStatus.SUCCESS.state) {
throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
inStream.toString(Charset.forName("UTF-8")));
}
}
@Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
saslClient.dispose();
ctx.close();
if (this.random == null) {
this.random = new Random();
}
exceptionHandler.handle(this.retryCount++, this.random, cause);
}
@Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
// If not complete, try to negotiate
if (!saslClient.isComplete()) {
super.write(ctx, msg, promise);
} else {
ByteBuf in = (ByteBuf) msg;
try {
saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
} catch (SaslException se) {
try {
saslClient.dispose();
} catch (SaslException ignored) {
LOG.debug("Ignoring SASL exception", ignored);
}
promise.setFailure(se);
}
if (saslToken != null) {
ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
out.writeInt(saslToken.length);
out.writeBytes(saslToken, 0, saslToken.length);
ctx.write(out).addListener(new ChannelFutureListener() {
@Override public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
exceptionCaught(ctx, future.cause());
}
}
});
saslToken = null;
}
}
}
/**
* Handler for exceptions during Sasl connection
*/
public interface SaslExceptionHandler {
/**
* Handle the exception
*
* @param retryCount current retry count
* @param random to create new backoff with
* @param cause of fail
*/
public void handle(int retryCount, Random random, Throwable cause);
}
/**
* Handler for successful connects
*/
public interface SaslSuccessfulConnectHandler {
/**
* Runs on success
*
* @param channel which is successfully authenticated
*/
public void onSuccess(Channel channel);
}
}

View File

@ -62,7 +62,7 @@ public final class HConstants {
/** /**
* The first four bytes of Hadoop RPC connections * The first four bytes of Hadoop RPC connections
*/ */
public static final ByteBuffer RPC_HEADER = ByteBuffer.wrap("HBas".getBytes()); public static final byte[] RPC_HEADER = new byte[] { 'H', 'B', 'a', 's' };
public static final byte RPC_CURRENT_VERSION = 0; public static final byte RPC_CURRENT_VERSION = 0;
// HFileBlock constants. // HFileBlock constants.

View File

@ -71,6 +71,15 @@ public class JVM {
return (ibmvendor ? linux : true); return (ibmvendor ? linux : true);
} }
/**
* Check if the OS is linux.
*
* @return whether this is linux or not.
*/
public static boolean isLinux() {
return linux;
}
/** /**
* Check if the finish() method of GZIPOutputStream is broken * Check if the finish() method of GZIPOutputStream is broken
* *

View File

@ -44,6 +44,7 @@ import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel; import java.nio.channels.WritableByteChannel;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
@ -1412,9 +1413,9 @@ public class RpcServer implements RpcServerInterface {
int count; int count;
// Check for 'HBas' magic. // Check for 'HBas' magic.
this.dataLengthBuffer.flip(); this.dataLengthBuffer.flip();
if (!HConstants.RPC_HEADER.equals(dataLengthBuffer)) { if (!Arrays.equals(HConstants.RPC_HEADER, dataLengthBuffer.array())) {
return doBadPreambleHandling("Expected HEADER=" + return doBadPreambleHandling("Expected HEADER=" +
Bytes.toStringBinary(HConstants.RPC_HEADER.array()) + Bytes.toStringBinary(HConstants.RPC_HEADER) +
" but received HEADER=" + Bytes.toStringBinary(dataLengthBuffer.array()) + " but received HEADER=" + Bytes.toStringBinary(dataLengthBuffer.array()) +
" from " + toString()); " from " + toString());
} }

View File

@ -33,9 +33,18 @@ import java.net.InetSocketAddress;
import java.net.Socket; import java.net.Socket;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.SocketFactory; import javax.net.SocketFactory;
import com.google.protobuf.BlockingRpcChannel;
import com.google.protobuf.RpcCallback;
import com.google.protobuf.RpcChannel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.SocketChannel;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
@ -44,10 +53,13 @@ import org.apache.hadoop.hbase.CellScannable;
import org.apache.hadoop.hbase.CellScanner; import org.apache.hadoop.hbase.CellScanner;
import org.apache.hadoop.hbase.CellUtil; import org.apache.hadoop.hbase.CellUtil;
import org.apache.hadoop.hbase.HBaseConfiguration; import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.HBaseTestingUtility;
import org.apache.hadoop.hbase.HConstants; import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.HRegionInfo; import org.apache.hadoop.hbase.HRegionInfo;
import org.apache.hadoop.hbase.KeyValue; import org.apache.hadoop.hbase.KeyValue;
import org.apache.hadoop.hbase.KeyValueUtil; import org.apache.hadoop.hbase.KeyValueUtil;
import org.apache.hadoop.hbase.ServerName;
import org.apache.hadoop.hbase.Waiter;
import org.apache.hadoop.hbase.testclassification.RPCTests; import org.apache.hadoop.hbase.testclassification.RPCTests;
import org.apache.hadoop.hbase.testclassification.SmallTests; import org.apache.hadoop.hbase.testclassification.SmallTests;
import org.apache.hadoop.hbase.client.Put; import org.apache.hadoop.hbase.client.Put;
@ -91,7 +103,10 @@ import com.google.protobuf.ServiceException;
*/ */
@Category({RPCTests.class, SmallTests.class}) @Category({RPCTests.class, SmallTests.class})
public class TestIPC { public class TestIPC {
private final static HBaseTestingUtility TEST_UTIL = new HBaseTestingUtility();
public static final Log LOG = LogFactory.getLog(TestIPC.class); public static final Log LOG = LogFactory.getLog(TestIPC.class);
static byte [] CELL_BYTES = Bytes.toBytes("xyz"); static byte [] CELL_BYTES = Bytes.toBytes("xyz");
static Cell CELL = new KeyValue(CELL_BYTES, CELL_BYTES, CELL_BYTES, CELL_BYTES); static Cell CELL = new KeyValue(CELL_BYTES, CELL_BYTES, CELL_BYTES, CELL_BYTES);
static byte [] BIG_CELL_BYTES = new byte [10 * 1024]; static byte [] BIG_CELL_BYTES = new byte [10 * 1024];
@ -191,8 +206,8 @@ public class TestIPC {
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
final String message = "hello"; final String message = "hello";
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage(message).build(); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage(message).build();
Pair<Message, CellScanner> r = client.call(null, md, param, null, Pair<Message, CellScanner> r = client.call(null, md, param,
md.getOutputType().toProto(), User.getCurrent(), address, 0); md.getOutputType().toProto(), User.getCurrent(), address);
assertTrue(r.getSecond() == null); assertTrue(r.getSecond() == null);
// Silly assertion that the message is in the returned pb. // Silly assertion that the message is in the returned pb.
assertTrue(r.getFirst().toString().contains(message)); assertTrue(r.getFirst().toString().contains(message));
@ -202,6 +217,44 @@ public class TestIPC {
} }
} }
/**
* Ensure we do not HAVE TO HAVE a codec.
*
* @throws InterruptedException
* @throws IOException
*/
@Test public void testNoCodecAsync() throws InterruptedException, IOException, ServiceException {
Configuration conf = HBaseConfiguration.create();
AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null) {
@Override Codec getCodec() {
return null;
}
};
TestRpcServer rpcServer = new TestRpcServer();
try {
rpcServer.start();
InetSocketAddress address = rpcServer.getListenerAddress();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
final String message = "hello";
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage(message).build();
BlockingRpcChannel channel = client
.createBlockingRpcChannel(ServerName.valueOf(address.getHostName(), address.getPort(),
System.currentTimeMillis()), User.getCurrent(), 0);
PayloadCarryingRpcController controller = new PayloadCarryingRpcController();
Message response =
channel.callBlockingMethod(md, controller, param, md.getOutputType().toProto());
assertTrue(controller.cellScanner() == null);
// Silly assertion that the message is in the returned pb.
assertTrue(response.toString().contains(message));
} finally {
client.close();
rpcServer.stop();
}
}
/** /**
* It is hard to verify the compression is actually happening under the wraps. Hope that if * It is hard to verify the compression is actually happening under the wraps. Hope that if
* unsupported, we'll get an exception out of some time (meantime, have to trace it manually * unsupported, we'll get an exception out of some time (meantime, have to trace it manually
@ -213,13 +266,17 @@ public class TestIPC {
*/ */
@Test @Test
public void testCompressCellBlock() public void testCompressCellBlock()
throws IOException, InterruptedException, SecurityException, NoSuchMethodException { throws IOException, InterruptedException, SecurityException, NoSuchMethodException,
ServiceException {
Configuration conf = new Configuration(HBaseConfiguration.create()); Configuration conf = new Configuration(HBaseConfiguration.create());
conf.set("hbase.client.rpc.compressor", GzipCodec.class.getCanonicalName()); conf.set("hbase.client.rpc.compressor", GzipCodec.class.getCanonicalName());
doSimpleTest(conf, new RpcClientImpl(conf, HConstants.CLUSTER_ID_DEFAULT)); doSimpleTest(new RpcClientImpl(conf, HConstants.CLUSTER_ID_DEFAULT));
// Another test for the async client
doAsyncSimpleTest(new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null));
} }
private void doSimpleTest(final Configuration conf, final RpcClientImpl client) private void doSimpleTest(final RpcClientImpl client)
throws InterruptedException, IOException { throws InterruptedException, IOException {
TestRpcServer rpcServer = new TestRpcServer(); TestRpcServer rpcServer = new TestRpcServer();
List<Cell> cells = new ArrayList<Cell>(); List<Cell> cells = new ArrayList<Cell>();
@ -230,8 +287,11 @@ public class TestIPC {
InetSocketAddress address = rpcServer.getListenerAddress(); InetSocketAddress address = rpcServer.getListenerAddress();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
Pair<Message, CellScanner> r = client.call(null, md, param, CellUtil.createCellScanner(cells),
md.getOutputType().toProto(), User.getCurrent(), address, 0); PayloadCarryingRpcController pcrc =
new PayloadCarryingRpcController(CellUtil.createCellScanner(cells));
Pair<Message, CellScanner> r = client
.call(pcrc, md, param, md.getOutputType().toProto(), User.getCurrent(), address);
int index = 0; int index = 0;
while (r.getSecond().advance()) { while (r.getSecond().advance()) {
assertTrue(CELL.equals(r.getSecond().current())); assertTrue(CELL.equals(r.getSecond().current()));
@ -244,6 +304,42 @@ public class TestIPC {
} }
} }
private void doAsyncSimpleTest(final AsyncRpcClient client)
throws InterruptedException, IOException, ServiceException {
TestRpcServer rpcServer = new TestRpcServer();
List<Cell> cells = new ArrayList<Cell>();
int count = 3;
for (int i = 0; i < count; i++)
cells.add(CELL);
try {
rpcServer.start();
InetSocketAddress address = rpcServer.getListenerAddress();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
PayloadCarryingRpcController pcrc =
new PayloadCarryingRpcController(CellUtil.createCellScanner(cells));
BlockingRpcChannel channel = client.createBlockingRpcChannel(
ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()),
User.getCurrent(), 0);
channel.callBlockingMethod(md, pcrc, param, md.getOutputType().toProto());
CellScanner cellScanner = pcrc.cellScanner();
int index = 0;
while (cellScanner.advance()) {
assertTrue(CELL.equals(cellScanner.current()));
index++;
}
assertEquals(count, index);
} finally {
client.close();
rpcServer.stop();
}
}
@Test @Test
public void testRTEDuringConnectionSetup() throws Exception { public void testRTEDuringConnectionSetup() throws Exception {
Configuration conf = HBaseConfiguration.create(); Configuration conf = HBaseConfiguration.create();
@ -264,7 +360,7 @@ public class TestIPC {
InetSocketAddress address = rpcServer.getListenerAddress(); InetSocketAddress address = rpcServer.getListenerAddress();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
client.call(null, md, param, null, null, User.getCurrent(), address, 0); client.call(null, md, param, null, User.getCurrent(), address);
fail("Expected an exception to have been thrown!"); fail("Expected an exception to have been thrown!");
} catch (Exception e) { } catch (Exception e) {
LOG.info("Caught expected exception: " + e.toString()); LOG.info("Caught expected exception: " + e.toString());
@ -275,6 +371,147 @@ public class TestIPC {
} }
} }
@Test
public void testRTEDuringAsyncBlockingConnectionSetup() throws Exception {
Configuration conf = HBaseConfiguration.create();
TestRpcServer rpcServer = new TestRpcServer();
AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null,
new ChannelInitializer<SocketChannel>() {
@Override protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addFirst(new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
promise.setFailure(new RuntimeException("Injected fault"));
}
});
}
});
try {
rpcServer.start();
InetSocketAddress address = rpcServer.getListenerAddress();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
BlockingRpcChannel channel = client.createBlockingRpcChannel(
ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()),
User.getCurrent(), 0);
channel.callBlockingMethod(md, new PayloadCarryingRpcController(), param,
md.getOutputType().toProto());
fail("Expected an exception to have been thrown!");
} catch (Exception e) {
LOG.info("Caught expected exception: " + e.toString());
assertTrue(StringUtils.stringifyException(e).contains("Injected fault"));
} finally {
client.close();
rpcServer.stop();
}
}
@Test
public void testRTEDuringAsyncConnectionSetup() throws Exception {
Configuration conf = HBaseConfiguration.create();
TestRpcServer rpcServer = new TestRpcServer();
AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null,
new ChannelInitializer<SocketChannel>() {
@Override protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addFirst(new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
promise.setFailure(new RuntimeException("Injected fault"));
}
});
}
});
try {
rpcServer.start();
InetSocketAddress address = rpcServer.getListenerAddress();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
RpcChannel channel = client.createRpcChannel(
ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()),
User.getCurrent(), 0);
final AtomicBoolean done = new AtomicBoolean(false);
PayloadCarryingRpcController controller = new PayloadCarryingRpcController();
controller.notifyOnFail(new RpcCallback<IOException>() {
@Override
public void run(IOException e) {
done.set(true);
LOG.info("Caught expected exception: " + e.toString());
assertTrue(StringUtils.stringifyException(e).contains("Injected fault"));
}
});
channel.callMethod(md, controller, param,
md.getOutputType().toProto(), new RpcCallback<Message>() {
@Override
public void run(Message parameter) {
done.set(true);
fail("Expected an exception to have been thrown!");
}
});
TEST_UTIL.waitFor(1000, new Waiter.Predicate<Exception>() {
@Override
public boolean evaluate() throws Exception {
return done.get();
}
});
} finally {
client.close();
rpcServer.stop();
}
}
@Test
public void testAsyncConnectionSetup() throws Exception {
Configuration conf = HBaseConfiguration.create();
TestRpcServer rpcServer = new TestRpcServer();
AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null);
try {
rpcServer.start();
InetSocketAddress address = rpcServer.getListenerAddress();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
RpcChannel channel = client.createRpcChannel(
ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()),
User.getCurrent(), 0);
final AtomicBoolean done = new AtomicBoolean(false);
channel.callMethod(md, new PayloadCarryingRpcController(), param,
md.getOutputType().toProto(), new RpcCallback<Message>() {
@Override
public void run(Message parameter) {
done.set(true);
}
});
TEST_UTIL.waitFor(1000, new Waiter.Predicate<Exception>() {
@Override
public boolean evaluate() throws Exception {
return done.get();
}
});
} finally {
client.close();
rpcServer.stop();
}
}
/** Tests that the rpc scheduler is called when requests arrive. */ /** Tests that the rpc scheduler is called when requests arrive. */
@Test @Test
public void testRpcScheduler() throws IOException, InterruptedException { public void testRpcScheduler() throws IOException, InterruptedException {
@ -288,8 +525,43 @@ public class TestIPC {
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
client.call(null, md, param, CellUtil.createCellScanner(ImmutableList.of(CELL)), client.call(
md.getOutputType().toProto(), User.getCurrent(), rpcServer.getListenerAddress(), 0); new PayloadCarryingRpcController(CellUtil.createCellScanner(ImmutableList.of(CELL))),
md, param, md.getOutputType().toProto(), User.getCurrent(),
rpcServer.getListenerAddress());
}
verify(scheduler, times(10)).dispatch((CallRunner) anyObject());
} finally {
rpcServer.stop();
verify(scheduler).stop();
}
}
/**
* Tests that the rpc scheduler is called when requests arrive.
*/
@Test
public void testRpcSchedulerAsync()
throws IOException, InterruptedException, ServiceException {
RpcScheduler scheduler = spy(new FifoRpcScheduler(CONF, 1));
RpcServer rpcServer = new TestRpcServer(scheduler);
verify(scheduler).init((RpcScheduler.Context) anyObject());
AbstractRpcClient client = new AsyncRpcClient(CONF, HConstants.CLUSTER_ID_DEFAULT, null);
try {
rpcServer.start();
verify(scheduler).start();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
ServerName serverName = ServerName.valueOf(rpcServer.getListenerAddress().getHostName(),
rpcServer.getListenerAddress().getPort(), System.currentTimeMillis());
for (int i = 0; i < 10; i++) {
BlockingRpcChannel channel = client.createBlockingRpcChannel(
serverName, User.getCurrent(), 0);
channel.callBlockingMethod(md,
new PayloadCarryingRpcController(CellUtil.createCellScanner(ImmutableList.of(CELL))),
param, md.getOutputType().toProto());
} }
verify(scheduler, times(10)).dispatch((CallRunner) anyObject()); verify(scheduler, times(10)).dispatch((CallRunner) anyObject());
} finally { } finally {
@ -341,9 +613,10 @@ public class TestIPC {
// ReflectionUtils.printThreadInfo(new PrintWriter(System.out), // ReflectionUtils.printThreadInfo(new PrintWriter(System.out),
// "Thread dump " + Thread.currentThread().getName()); // "Thread dump " + Thread.currentThread().getName());
} }
CellScanner cellScanner = CellUtil.createCellScanner(cells); PayloadCarryingRpcController pcrc =
new PayloadCarryingRpcController(CellUtil.createCellScanner(cells));
Pair<Message, CellScanner> response = Pair<Message, CellScanner> response =
client.call(null, md, builder.build(), cellScanner, param, user, address, 0); client.call(pcrc, md, builder.build(), param, user, address);
/* /*
int count = 0; int count = 0;
while (p.getSecond().advance()) { while (p.getSecond().advance()) {

View File

@ -17,13 +17,13 @@
*/ */
package org.apache.hadoop.hbase.ipc; package org.apache.hadoop.hbase.ipc;
import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.io.IOException; import com.google.protobuf.BlockingService;
import java.net.InetSocketAddress; import com.google.protobuf.Descriptors.MethodDescriptor;
import java.util.ArrayList; import com.google.protobuf.Message;
import java.util.List; import com.google.protobuf.RpcController;
import com.google.protobuf.ServiceException;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
@ -49,13 +49,12 @@ import org.junit.Test;
import org.junit.experimental.categories.Category; import org.junit.experimental.categories.Category;
import org.mockito.Mockito; import org.mockito.Mockito;
import com.google.common.collect.ImmutableList; import java.io.IOException;
import com.google.common.collect.Lists; import java.net.InetSocketAddress;
import com.google.protobuf.BlockingService; import java.util.ArrayList;
import com.google.protobuf.Descriptors.MethodDescriptor; import java.util.List;
import com.google.protobuf.Message;
import com.google.protobuf.RpcController; import static org.mockito.Mockito.mock;
import com.google.protobuf.ServiceException;
@Category({RPCTests.class, SmallTests.class}) @Category({RPCTests.class, SmallTests.class})
public class TestRpcHandlerException { public class TestRpcHandlerException {
@ -178,8 +177,11 @@ public class TestRpcHandlerException {
rpcServer.start(); rpcServer.start();
MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo");
EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build();
client.call(null, md, param, CellUtil.createCellScanner(ImmutableList.of(CELL)), md PayloadCarryingRpcController controller =
.getOutputType().toProto(), User.getCurrent(), rpcServer.getListenerAddress(), 0); new PayloadCarryingRpcController(CellUtil.createCellScanner(ImmutableList.of(CELL)));
client.call(controller, md, param, md.getOutputType().toProto(), User.getCurrent(),
rpcServer.getListenerAddress());
} catch (Throwable e) { } catch (Throwable e) {
assert(abortable.isAborted() == true); assert(abortable.isAborted() == true);
} finally { } finally {