HADOOP-6419. Adds SASL based authentication to RPC. Contributed by Kan Zhang.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@905860 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Devaraj Das 2010-02-03 01:30:25 +00:00
parent fe0ddc03e1
commit 940389afce
19 changed files with 1713 additions and 80 deletions

View File

@ -48,6 +48,9 @@ Trunk (unreleased changes)
upon login. The tokens are read from a file specified in the
environment variable. (ddas)
HADOOP-6419. Adds SASL based authentication to RPC.
(Kan Zhang via ddas)
IMPROVEMENTS
HADOOP-6283. Improve the exception messages thrown by

View File

@ -33,6 +33,8 @@ import org.apache.commons.logging.*;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.net.NetUtils;
import org.apache.avro.*;
@ -192,10 +194,13 @@ class AvroRpcEngine implements RpcEngine {
* port and address. */
public RPC.Server getServer(Class iface, Object impl, String bindAddress,
int port, int numHandlers, boolean verbose,
Configuration conf) throws IOException {
Configuration conf,
SecretManager<? extends TokenIdentifier> secretManager
) throws IOException {
return ENGINE.getServer(TunnelProtocol.class,
new TunnelResponder(iface, impl),
bindAddress, port, numHandlers, verbose, conf);
bindAddress, port, numHandlers, verbose, conf,
secretManager);
}
}

View File

@ -31,7 +31,9 @@ import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FilterInputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.PrivilegedExceptionAction;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map.Entry;
@ -44,11 +46,19 @@ import org.apache.commons.logging.*;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.util.ReflectionUtils;
/** A client for an IPC service. IPC calls take a single {@link Writable} as a
@ -196,8 +206,13 @@ public class Client {
* socket: responses may be delivered out of order. */
private class Connection extends Thread {
private InetSocketAddress server; // server ip:port
private String serverPrincipal; // server's krb5 principal name
private ConnectionHeader header; // connection header
private ConnectionId remoteId; // connection id
private final ConnectionId remoteId; // connection id
private final AuthMethod authMethod; // authentication method
private final boolean useSasl;
private Token<? extends TokenIdentifier> token;
private SaslRpcClient saslRpcClient;
private Socket socket = null; // connected socket
private DataInputStream in;
@ -221,6 +236,42 @@ public class Client {
Class<?> protocol = remoteId.getProtocol();
header =
new ConnectionHeader(protocol == null ? null : protocol.getName(), ticket);
this.useSasl = UserGroupInformation.isSecurityEnabled();
if (useSasl && protocol != null) {
TokenInfo tokenInfo = protocol.getAnnotation(TokenInfo.class);
if (tokenInfo != null) {
TokenSelector<? extends TokenIdentifier> tokenSelector = null;
try {
tokenSelector = tokenInfo.value().newInstance();
} catch (InstantiationException e) {
throw new IOException(e.toString());
} catch (IllegalAccessException e) {
throw new IOException(e.toString());
}
InetSocketAddress addr = remoteId.getAddress();
token = tokenSelector.selectToken(new Text(addr.getAddress()
.getHostAddress() + ":" + addr.getPort()),
ticket.getTokens());
}
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
if (krbInfo != null) {
String serverKey = krbInfo.value();
if (serverKey != null) {
serverPrincipal = conf.get(serverKey);
}
}
}
if (!useSasl) {
authMethod = AuthMethod.SIMPLE;
} else if (token != null) {
authMethod = AuthMethod.DIGEST;
} else {
authMethod = AuthMethod.KERBEROS;
}
if (LOG.isDebugEnabled())
LOG.debug("Use " + authMethod + " authentication for protocol "
+ protocol.getSimpleName());
this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " +
remoteId.getAddress().toString() +
@ -302,11 +353,20 @@ public class Client {
}
}
private synchronized void disposeSasl() {
if (saslRpcClient != null) {
try {
saslRpcClient.dispose();
} catch (IOException ignored) {
}
}
}
/** Connect to the server and set up the I/O streams. It then sends
* a header to the server and starts
* the connection thread that waits for responses.
*/
private synchronized void setupIOstreams() {
private synchronized void setupIOstreams() throws InterruptedException {
if (socket != null || shouldCloseConnection.get()) {
return;
}
@ -334,15 +394,33 @@ public class Client {
handleConnectionFailure(ioFailures++, maxRetries, ie);
}
}
InputStream inStream = NetUtils.getInputStream(socket);
OutputStream outStream = NetUtils.getOutputStream(socket);
writeRpcHeader(outStream);
if (useSasl) {
final InputStream in2 = inStream;
final OutputStream out2 = outStream;
remoteId.getTicket().doAs(new PrivilegedExceptionAction<Object>() {
@Override
public Object run() throws IOException {
saslRpcClient = new SaslRpcClient(authMethod, token,
serverPrincipal);
saslRpcClient.saslConnect(in2, out2);
return null;
}
});
inStream = saslRpcClient.getInputStream(inStream);
outStream = saslRpcClient.getOutputStream(outStream);
}
if (doPing) {
this.in = new DataInputStream(new BufferedInputStream
(new PingInputStream(NetUtils.getInputStream(socket))));
(new PingInputStream(inStream)));
} else {
this.in = new DataInputStream(new BufferedInputStream
(NetUtils.getInputStream(socket)));
(inStream));
}
this.out = new DataOutputStream
(new BufferedOutputStream(NetUtils.getOutputStream(socket)));
(new BufferedOutputStream(outStream));
writeHeader();
// update last activity time
@ -396,14 +474,20 @@ public class Client {
". Already tried " + curRetries + " time(s).");
}
/* Write the header for each connection
/* Write the RPC header */
private void writeRpcHeader(OutputStream outStream) throws IOException {
DataOutputStream out = new DataOutputStream(new BufferedOutputStream(outStream));
// Write out the header, version and authentication method
out.write(Server.HEADER.array());
out.write(Server.CURRENT_VERSION);
authMethod.write(out);
out.flush();
}
/* Write the protocol header for each connection
* Out is not synchronized because only the first thread does this.
*/
private void writeHeader() throws IOException {
// Write out the header and version
out.write(Server.HEADER.array());
out.write(Server.CURRENT_VERSION);
// Write out the ConnectionHeader
DataOutputBuffer buf = new DataOutputBuffer();
header.write(buf);
@ -575,6 +659,7 @@ public class Client {
// close the streams and therefore the socket
IOUtils.closeStream(out);
IOUtils.closeStream(in);
disposeSasl();
// clean up all calls
if (closeException == null) {
@ -815,7 +900,7 @@ public class Client {
*/
@Deprecated
public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
throws IOException {
throws IOException, InterruptedException {
return call(params, addresses, null, null);
}
@ -825,7 +910,7 @@ public class Client {
* contains nulls for calls that timed out or errored. */
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
Class<?> protocol, UserGroupInformation ticket)
throws IOException {
throws IOException, InterruptedException {
if (addresses.length == 0) return new Writable[0];
ParallelResults results = new ParallelResults(params.length);
@ -859,7 +944,7 @@ public class Client {
Class<?> protocol,
UserGroupInformation ticket,
Call call)
throws IOException {
throws IOException, InterruptedException {
if (!running.get()) {
// the client is stopped
throw new IOException("The client is stopped");

View File

@ -37,6 +37,8 @@ import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.AuthorizationException;
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.conf.*;
import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
import org.apache.hadoop.util.ReflectionUtils;
@ -254,7 +256,7 @@ public class RPC {
@Deprecated
public static Object[] call(Method method, Object[][] params,
InetSocketAddress[] addrs, Configuration conf)
throws IOException {
throws IOException, InterruptedException {
return call(method, params, addrs, null, conf);
}
@ -262,7 +264,7 @@ public class RPC {
public static Object[] call(Method method, Object[][] params,
InetSocketAddress[] addrs,
UserGroupInformation ticket, Configuration conf)
throws IOException {
throws IOException, InterruptedException {
return getProtocolEngine(method.getDeclaringClass(), conf)
.call(method, params, addrs, ticket, conf);
@ -288,7 +290,7 @@ public class RPC {
final boolean verbose, Configuration conf)
throws IOException {
return getServer(instance.getClass(), // use impl class for protocol
instance, bindAddress, port, numHandlers, false, conf);
instance, bindAddress, port, numHandlers, false, conf, null);
}
/** Construct a server for a protocol implementation instance. */
@ -296,19 +298,34 @@ public class RPC {
Object instance, String bindAddress,
int port, Configuration conf)
throws IOException {
return getServer(protocol, instance, bindAddress, port, 1, false, conf);
return getServer(protocol, instance, bindAddress, port, 1, false, conf, null);
}
/** Construct a server for a protocol implementation instance. */
/** Construct a server for a protocol implementation instance.
* @deprecated secretManager should be passed.
*/
@Deprecated
public static Server getServer(Class protocol,
Object instance, String bindAddress, int port,
int numHandlers,
boolean verbose, Configuration conf)
throws IOException {
return getServer(protocol, instance, bindAddress, port, numHandlers, verbose,
conf, null);
}
/** Construct a server for a protocol implementation instance. */
public static Server getServer(Class<?> protocol,
Object instance, String bindAddress, int port,
int numHandlers,
boolean verbose, Configuration conf,
SecretManager<? extends TokenIdentifier> secretManager)
throws IOException {
return getProtocolEngine(protocol, conf)
.getServer(protocol, instance, bindAddress, port, numHandlers, verbose,
conf);
conf, secretManager);
}
/** An RPC Server. */
@ -316,8 +333,9 @@ public class RPC {
protected Server(String bindAddress, int port,
Class<? extends Writable> paramClass, int handlerCount,
Configuration conf, String serverName) throws IOException {
super(bindAddress, port, paramClass, handlerCount, conf, serverName);
Configuration conf, String serverName,
SecretManager<? extends TokenIdentifier> secretManager) throws IOException {
super(bindAddress, port, paramClass, handlerCount, conf, serverName, secretManager);
}
}

View File

@ -24,6 +24,8 @@ import java.net.InetSocketAddress;
import javax.net.SocketFactory;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.conf.Configuration;
/** An RPC implementation. */
@ -41,11 +43,13 @@ interface RpcEngine {
/** Expert: Make multiple, parallel calls to a set of servers. */
Object[] call(Method method, Object[][] params, InetSocketAddress[] addrs,
UserGroupInformation ticket, Configuration conf)
throws IOException;
throws IOException, InterruptedException;
/** Construct a server for a protocol implementation instance. */
RPC.Server getServer(Class protocol, Object instance, String bindAddress,
int port, int numHandlers, boolean verbose,
Configuration conf) throws IOException;
Configuration conf,
SecretManager<? extends TokenIdentifier> secretManager
) throws IOException;
}

View File

@ -32,6 +32,7 @@ import java.net.SocketException;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
@ -39,7 +40,6 @@ import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
@ -52,15 +52,26 @@ import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.ipc.metrics.RpcMetrics;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SaslRpcServer.SaslDigestCallbackHandler;
import org.apache.hadoop.security.SaslRpcServer.SaslGssCallbackHandler;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.AuthorizationException;
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.StringUtils;
@ -80,7 +91,8 @@ public abstract class Server {
// 1 : Introduce ping and server does not throw away RPCs
// 3 : Introduce the protocol into the RPC connection header
public static final byte CURRENT_VERSION = 3;
// 4 : Introduced SASL security layer
public static final byte CURRENT_VERSION = 4;
/**
* How many calls/handler are allowed in the queue.
@ -158,6 +170,7 @@ public abstract class Server {
protected RpcMetrics rpcMetrics;
private Configuration conf;
private SecretManager<TokenIdentifier> secretManager;
private int maxQueueSize;
private int socketSendBufferSize;
@ -431,7 +444,7 @@ public abstract class Server {
if (count < 0) {
if (LOG.isDebugEnabled())
LOG.debug(getName() + ": disconnecting client " +
c.getHostAddress() + ". Number of active connections: "+
c + ". Number of active connections: "+
numConnections);
closeConnection(c);
c = null;
@ -703,8 +716,7 @@ public abstract class Server {
/** Reads calls from a connection and queues them for handling. */
private class Connection {
private boolean versionRead = false; //if initial signature and
//version are read
private boolean rpcHeaderRead = false; // if initial rpc header is read
private boolean headerRead = false; //if the connection header that
//follows version is read.
@ -723,6 +735,13 @@ public abstract class Server {
ConnectionHeader header = new ConnectionHeader();
Class<?> protocol;
boolean useSasl;
SaslServer saslServer;
private AuthMethod authMethod;
private boolean saslContextEstablished;
private ByteBuffer rpcHeaderBuffer;
private ByteBuffer unwrappedData;
private ByteBuffer unwrappedDataLengthBuffer;
UserGroupInformation user = null;
@ -731,6 +750,10 @@ public abstract class Server {
private final Call authFailedCall =
new Call(AUTHROIZATION_FAILED_CALLID, null, null);
private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
// Fake 'call' for SASL context setup
private static final int SASL_CALLID = -33;
private final Call saslCall = new Call(SASL_CALLID, null, null);
private final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream();
public Connection(SelectionKey key, SocketChannel channel,
long lastContact) {
@ -738,6 +761,8 @@ public abstract class Server {
this.lastContact = lastContact;
this.data = null;
this.dataLengthBuffer = ByteBuffer.allocate(4);
this.unwrappedData = null;
this.unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
this.socket = channel.socket();
InetAddress addr = socket.getInetAddress();
if (addr == null) {
@ -795,6 +820,92 @@ public abstract class Server {
return false;
}
private void saslReadAndProcess(byte[] saslToken) throws IOException,
InterruptedException {
if (!saslContextEstablished) {
if (saslServer == null) {
switch (authMethod) {
case DIGEST:
saslServer = Sasl.createSaslServer(AuthMethod.DIGEST
.getMechanismName(), null, SaslRpcServer.SASL_DEFAULT_REALM,
SaslRpcServer.SASL_PROPS, new SaslDigestCallbackHandler(
secretManager));
break;
default:
UserGroupInformation current = UserGroupInformation
.getCurrentUser();
String fullName = current.getUserName();
if (LOG.isDebugEnabled())
LOG.debug("Kerberos principal name is " + fullName);
final String names[] = SaslRpcServer.splitKerberosName(fullName);
if (names.length != 3) {
throw new IOException(
"Kerberos principal name does NOT have the expected "
+ "hostname part: " + fullName);
}
current.doAs(new PrivilegedExceptionAction<Object>() {
@Override
public Object run() throws IOException {
saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS
.getMechanismName(), names[0], names[1],
SaslRpcServer.SASL_PROPS, new SaslGssCallbackHandler());
return null;
}
});
}
if (saslServer == null)
throw new IOException(
"Unable to find SASL server implementation for "
+ authMethod.getMechanismName());
if (LOG.isDebugEnabled())
LOG.debug("Created SASL server with mechanism = "
+ authMethod.getMechanismName());
}
if (LOG.isDebugEnabled())
LOG.debug("Have read input token of size " + saslToken.length
+ " for processing by saslServer.evaluateResponse()");
byte[] replyToken = saslServer.evaluateResponse(saslToken);
if (replyToken != null) {
if (LOG.isDebugEnabled())
LOG.debug("Will send token of size " + replyToken.length
+ " from saslServer.");
saslCall.connection = this;
saslResponse.reset();
DataOutputStream out = new DataOutputStream(saslResponse);
out.writeInt(replyToken.length);
out.write(replyToken, 0, replyToken.length);
saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray()));
responder.doRespond(saslCall);
}
if (saslServer.isComplete()) {
if (LOG.isDebugEnabled()) {
LOG.debug("SASL server context established. Negotiated QoP is "
+ saslServer.getNegotiatedProperty(Sasl.QOP));
}
user = UserGroupInformation.createRemoteUser(saslServer
.getAuthorizationID());
LOG.info("SASL server successfully authenticated client: " + user);
saslContextEstablished = true;
}
} else {
if (LOG.isDebugEnabled())
LOG.debug("Have read input token of size " + saslToken.length
+ " for processing by saslServer.unwrap()");
byte[] plaintextData = saslServer
.unwrap(saslToken, 0, saslToken.length);
processUnwrappedData(plaintextData);
}
}
private void disposeSasl() {
if (saslServer != null) {
try {
saslServer.dispose();
} catch (SaslException ignored) {
}
}
}
public int readAndProcess() throws IOException, InterruptedException {
while (true) {
/* Read at most one RPC. If the header is not read completely yet
@ -807,14 +918,33 @@ public abstract class Server {
return count;
}
if (!versionRead) {
if (!rpcHeaderRead) {
//Every connection is expected to send the header.
ByteBuffer versionBuffer = ByteBuffer.allocate(1);
count = channelRead(channel, versionBuffer);
if (count <= 0) {
if (rpcHeaderBuffer == null) {
rpcHeaderBuffer = ByteBuffer.allocate(2);
}
count = channelRead(channel, rpcHeaderBuffer);
if (count < 0 || rpcHeaderBuffer.remaining() > 0) {
return count;
}
int version = versionBuffer.get(0);
int version = rpcHeaderBuffer.get(0);
byte[] method = new byte[] {rpcHeaderBuffer.get(1)};
authMethod = AuthMethod.read(new DataInputStream(
new ByteArrayInputStream(method)));
if (authMethod == null) {
throw new IOException("Unable to read authentication method");
}
if (UserGroupInformation.isSecurityEnabled()
&& authMethod == AuthMethod.SIMPLE) {
throw new IOException("Authentication is required");
}
if (!UserGroupInformation.isSecurityEnabled()
&& authMethod != AuthMethod.SIMPLE) {
throw new IOException("Authentication is not supported");
}
if (authMethod != AuthMethod.SIMPLE) {
useSasl = true;
}
dataLengthBuffer.flip();
if (!HEADER.equals(dataLengthBuffer) || version != CURRENT_VERSION) {
@ -826,7 +956,8 @@ public abstract class Server {
return -1;
}
dataLengthBuffer.clear();
versionRead = true;
rpcHeaderBuffer = null;
rpcHeaderRead = true;
continue;
}
@ -834,12 +965,11 @@ public abstract class Server {
dataLengthBuffer.flip();
dataLength = dataLengthBuffer.getInt();
if (dataLength == Client.PING_CALL_ID) {
if (!useSasl && dataLength == Client.PING_CALL_ID) {
dataLengthBuffer.clear();
return 0; //ping message
}
data = ByteBuffer.allocate(dataLength);
incRpcCount(); // Increment the rpc count
}
count = channelRead(channel, data);
@ -847,33 +977,14 @@ public abstract class Server {
if (data.remaining() == 0) {
dataLengthBuffer.clear();
data.flip();
if (headerRead) {
processData();
data = null;
return count;
boolean isHeaderRead = headerRead;
if (useSasl) {
saslReadAndProcess(data.array());
} else {
processHeader();
headerRead = true;
data = null;
// Authorize the connection
try {
authorize(user, header);
if (LOG.isDebugEnabled()) {
LOG.debug("Successfully authorized " + header);
}
} catch (AuthorizationException ae) {
authFailedCall.connection = this;
setupResponse(authFailedResponse, authFailedCall,
Status.FATAL, null,
ae.getClass().getName(), ae.getMessage());
responder.doRespond(authFailedCall);
// Close this connection
return -1;
}
processOneRpc(data.array());
}
data = null;
if (!isHeaderRead) {
continue;
}
}
@ -882,9 +993,9 @@ public abstract class Server {
}
/// Reads the connection header following version
private void processHeader() throws IOException {
private void processHeader(byte[] buf) throws IOException {
DataInputStream in =
new DataInputStream(new ByteArrayInputStream(data.array()));
new DataInputStream(new ByteArrayInputStream(buf));
header.readFields(in);
try {
String protocolClassName = header.getProtocol();
@ -895,12 +1006,73 @@ public abstract class Server {
throw new IOException("Unknown protocol: " + header.getProtocol());
}
user = header.getUgi();
UserGroupInformation protocolUser = header.getUgi();
if (!useSasl) {
user = protocolUser;
} else if (protocolUser != null && !protocolUser.equals(user)) {
throw new AccessControlException("Authenticated user (" + user
+ ") doesn't match what the client claims to be (" + protocolUser
+ ")");
}
}
private void processData() throws IOException, InterruptedException {
private void processUnwrappedData(byte[] inBuf) throws IOException,
InterruptedException {
ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(
inBuf));
// Read all RPCs contained in the inBuf, even partial ones
while (true) {
int count = -1;
if (unwrappedDataLengthBuffer.remaining() > 0) {
count = channelRead(ch, unwrappedDataLengthBuffer);
if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0)
return;
}
if (unwrappedData == null) {
unwrappedDataLengthBuffer.flip();
int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
if (unwrappedDataLength == Client.PING_CALL_ID) {
if (LOG.isDebugEnabled())
LOG.debug("Received ping message");
unwrappedDataLengthBuffer.clear();
continue; // ping message
}
unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
}
count = channelRead(ch, unwrappedData);
if (count <= 0 || unwrappedData.remaining() > 0)
return;
if (unwrappedData.remaining() == 0) {
unwrappedDataLengthBuffer.clear();
unwrappedData.flip();
processOneRpc(unwrappedData.array());
unwrappedData = null;
}
}
}
private void processOneRpc(byte[] buf) throws IOException,
InterruptedException {
if (headerRead) {
processData(buf);
} else {
processHeader(buf);
headerRead = true;
if (!authorizeConnection()) {
throw new AccessControlException("Connection from " + this
+ " for protocol " + header.getProtocol()
+ " is unauthorized for user " + user);
}
}
}
private void processData(byte[] buf) throws IOException, InterruptedException {
DataInputStream dis =
new DataInputStream(new ByteArrayInputStream(data.array()));
new DataInputStream(new ByteArrayInputStream(buf));
int id = dis.readInt(); // try to read an id
if (LOG.isDebugEnabled())
@ -911,9 +1083,27 @@ public abstract class Server {
Call call = new Call(id, param, this);
callQueue.put(call); // queue the call; maybe blocked here
incRpcCount(); // Increment the rpc count
}
private boolean authorizeConnection() throws IOException {
try {
authorize(user, header);
if (LOG.isDebugEnabled()) {
LOG.debug("Successfully authorized " + header);
}
} catch (AuthorizationException ae) {
authFailedCall.connection = this;
setupResponse(authFailedResponse, authFailedCall, Status.FATAL, null,
ae.getClass().getName(), ae.getMessage());
responder.doRespond(authFailedCall);
return false;
}
return true;
}
private synchronized void close() throws IOException {
disposeSasl();
data = null;
dataLengthBuffer = null;
if (!channel.isOpen())
@ -1011,16 +1201,17 @@ public abstract class Server {
Configuration conf)
throws IOException
{
this(bindAddress, port, paramClass, handlerCount, conf, Integer.toString(port));
this(bindAddress, port, paramClass, handlerCount, conf, Integer.toString(port), null);
}
/** Constructs a server listening on the named port and address. Parameters passed must
* be of the named class. The <code>handlerCount</handlerCount> determines
* the number of handler threads that will be used to process calls.
*
*/
@SuppressWarnings("unchecked")
protected Server(String bindAddress, int port,
Class<? extends Writable> paramClass, int handlerCount,
Configuration conf, String serverName)
Configuration conf, String serverName, SecretManager<? extends TokenIdentifier> secretManager)
throws IOException {
this.bindAddress = bindAddress;
this.conf = conf;
@ -1033,6 +1224,7 @@ public abstract class Server {
this.maxIdleTime = 2*conf.getInt("ipc.client.connection.maxidletime", 1000);
this.maxConnectionsToNuke = conf.getInt("ipc.client.kill.max", 10);
this.thresholdIdleConnections = conf.getInt("ipc.client.idlethreshold", 4000);
this.secretManager = (SecretManager<TokenIdentifier>) secretManager;
this.authorize =
conf.getBoolean(ServiceAuthorizationManager.SERVICE_AUTHORIZATION_CONFIG,
false);
@ -1086,9 +1278,29 @@ public abstract class Server {
WritableUtils.writeString(out, errorClass);
WritableUtils.writeString(out, error);
}
wrapWithSasl(response, call);
call.setResponse(ByteBuffer.wrap(response.toByteArray()));
}
private void wrapWithSasl(ByteArrayOutputStream response, Call call)
throws IOException {
if (call.connection.useSasl) {
byte[] token = response.toByteArray();
// synchronization may be needed since there can be multiple Handler
// threads using saslServer to wrap responses.
synchronized (call.connection.saslServer) {
token = call.connection.saslServer.wrap(token, 0, token.length);
}
if (LOG.isDebugEnabled())
LOG.debug("Adding saslServer wrapped token of size " + token.length
+ " as call response.");
response.reset();
DataOutputStream saslOut = new DataOutputStream(response);
saslOut.writeInt(token.length);
saslOut.write(token, 0, token.length);
}
}
Configuration getConf() {
return conf;
}

View File

@ -36,6 +36,8 @@ import org.apache.commons.logging.*;
import org.apache.hadoop.io.*;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.conf.*;
import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
@ -246,7 +248,7 @@ class WritableRpcEngine implements RpcEngine {
public Object[] call(Method method, Object[][] params,
InetSocketAddress[] addrs,
UserGroupInformation ticket, Configuration conf)
throws IOException {
throws IOException, InterruptedException {
Invocation[] invocations = new Invocation[params.length];
for (int i = 0; i < params.length; i++)
@ -276,9 +278,11 @@ class WritableRpcEngine implements RpcEngine {
* port and address. */
public Server getServer(Class protocol,
Object instance, String bindAddress, int port,
int numHandlers, boolean verbose, Configuration conf)
int numHandlers, boolean verbose, Configuration conf,
SecretManager<? extends TokenIdentifier> secretManager)
throws IOException {
return new Server(instance, conf, bindAddress, port, numHandlers, verbose);
return new Server(instance, conf, bindAddress, port, numHandlers,
verbose, secretManager);
}
/** An RPC Server. */
@ -294,7 +298,7 @@ class WritableRpcEngine implements RpcEngine {
*/
public Server(Object instance, Configuration conf, String bindAddress, int port)
throws IOException {
this(instance, conf, bindAddress, port, 1, false);
this(instance, conf, bindAddress, port, 1, false, null);
}
private static String classNameBase(String className) {
@ -314,8 +318,11 @@ class WritableRpcEngine implements RpcEngine {
* @param verbose whether each call should be logged
*/
public Server(Object instance, Configuration conf, String bindAddress, int port,
int numHandlers, boolean verbose) throws IOException {
super(bindAddress, port, Invocation.class, numHandlers, conf, classNameBase(instance.getClass().getName()));
int numHandlers, boolean verbose,
SecretManager<? extends TokenIdentifier> secretManager)
throws IOException {
super(bindAddress, port, Invocation.class, numHandlers, conf,
classNameBase(instance.getClass().getName()), secretManager);
this.instance = instance;
this.verbose = verbose;
}

View File

@ -0,0 +1,31 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.lang.annotation.*;
/**
* Indicates Kerberos related information to be used
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface KerberosInfo {
/** Key for getting server's Kerberos principal name from Configuration */
String value();
}

View File

@ -0,0 +1,321 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.InputStream;
import java.io.IOException;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* A SaslInputStream is composed of an InputStream and a SaslServer (or
* SaslClient) so that read() methods return data that are read in from the
* underlying InputStream but have been additionally processed by the SaslServer
* (or SaslClient) object. The SaslServer (or SaslClient) object must be fully
* initialized before being used by a SaslInputStream.
*/
public class SaslInputStream extends InputStream {
public static final Log LOG = LogFactory.getLog(SaslInputStream.class);
private final DataInputStream inStream;
/*
* data read from the underlying input stream before being processed by SASL
*/
private byte[] saslToken;
private final SaslClient saslClient;
private final SaslServer saslServer;
private byte[] lengthBuf = new byte[4];
/*
* buffer holding data that have been processed by SASL, but have not been
* read out
*/
private byte[] obuffer;
// position of the next "new" byte
private int ostart = 0;
// position of the last "new" byte
private int ofinish = 0;
private static int unsignedBytesToInt(byte[] buf) {
if (buf.length != 4) {
throw new IllegalArgumentException(
"Cannot handle byte array other than 4 bytes");
}
int result = 0;
for (int i = 0; i < 4; i++) {
result <<= 8;
result |= ((int) buf[i] & 0xff);
}
return result;
}
/**
* Read more data and get them processed <br>
* Entry condition: ostart = ofinish <br>
* Exit condition: ostart <= ofinish <br>
*
* return (ofinish-ostart) (we have this many bytes for you), 0 (no data now,
* but could have more later), or -1 (absolutely no more data)
*/
private int readMoreData() throws IOException {
try {
inStream.readFully(lengthBuf);
int length = unsignedBytesToInt(lengthBuf);
if (LOG.isDebugEnabled())
LOG.debug("Actual length is " + length);
saslToken = new byte[length];
inStream.readFully(saslToken);
} catch (EOFException e) {
return -1;
}
try {
if (saslServer != null) { // using saslServer
obuffer = saslServer.unwrap(saslToken, 0, saslToken.length);
} else { // using saslClient
obuffer = saslClient.unwrap(saslToken, 0, saslToken.length);
}
} catch (SaslException se) {
try {
disposeSasl();
} catch (SaslException ignored) {
}
throw se;
}
ostart = 0;
if (obuffer == null)
ofinish = 0;
else
ofinish = obuffer.length;
return ofinish;
}
/**
* Disposes of any system resources or security-sensitive information Sasl
* might be using.
*
* @exception SaslException
* if a SASL error occurs.
*/
private void disposeSasl() throws SaslException {
if (saslClient != null) {
saslClient.dispose();
}
if (saslServer != null) {
saslServer.dispose();
}
}
/**
* Constructs a SASLInputStream from an InputStream and a SaslServer <br>
* Note: if the specified InputStream or SaslServer is null, a
* NullPointerException may be thrown later when they are used.
*
* @param inStream
* the InputStream to be processed
* @param saslServer
* an initialized SaslServer object
*/
public SaslInputStream(InputStream inStream, SaslServer saslServer) {
this.inStream = new DataInputStream(inStream);
this.saslServer = saslServer;
this.saslClient = null;
}
/**
* Constructs a SASLInputStream from an InputStream and a SaslClient <br>
* Note: if the specified InputStream or SaslClient is null, a
* NullPointerException may be thrown later when they are used.
*
* @param inStream
* the InputStream to be processed
* @param saslClient
* an initialized SaslClient object
*/
public SaslInputStream(InputStream inStream, SaslClient saslClient) {
this.inStream = new DataInputStream(inStream);
this.saslServer = null;
this.saslClient = saslClient;
}
/**
* Reads the next byte of data from this input stream. The value byte is
* returned as an <code>int</code> in the range <code>0</code> to
* <code>255</code>. If no byte is available because the end of the stream has
* been reached, the value <code>-1</code> is returned. This method blocks
* until input data is available, the end of the stream is detected, or an
* exception is thrown.
* <p>
*
* @return the next byte of data, or <code>-1</code> if the end of the stream
* is reached.
* @exception IOException
* if an I/O error occurs.
*/
public int read() throws IOException {
if (ostart >= ofinish) {
// we loop for new data as we are blocking
int i = 0;
while (i == 0)
i = readMoreData();
if (i == -1)
return -1;
}
return ((int) obuffer[ostart++] & 0xff);
}
/**
* Reads up to <code>b.length</code> bytes of data from this input stream into
* an array of bytes.
* <p>
* The <code>read</code> method of <code>InputStream</code> calls the
* <code>read</code> method of three arguments with the arguments
* <code>b</code>, <code>0</code>, and <code>b.length</code>.
*
* @param b
* the buffer into which the data is read.
* @return the total number of bytes read into the buffer, or <code>-1</code>
* is there is no more data because the end of the stream has been
* reached.
* @exception IOException
* if an I/O error occurs.
*/
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
/**
* Reads up to <code>len</code> bytes of data from this input stream into an
* array of bytes. This method blocks until some input is available. If the
* first argument is <code>null,</code> up to <code>len</code> bytes are read
* and discarded.
*
* @param b
* the buffer into which the data is read.
* @param off
* the start offset of the data.
* @param len
* the maximum number of bytes read.
* @return the total number of bytes read into the buffer, or <code>-1</code>
* if there is no more data because the end of the stream has been
* reached.
* @exception IOException
* if an I/O error occurs.
*/
public int read(byte[] b, int off, int len) throws IOException {
if (ostart >= ofinish) {
// we loop for new data as we are blocking
int i = 0;
while (i == 0)
i = readMoreData();
if (i == -1)
return -1;
}
if (len <= 0) {
return 0;
}
int available = ofinish - ostart;
if (len < available)
available = len;
if (b != null) {
System.arraycopy(obuffer, ostart, b, off, available);
}
ostart = ostart + available;
return available;
}
/**
* Skips <code>n</code> bytes of input from the bytes that can be read from
* this input stream without blocking.
*
* <p>
* Fewer bytes than requested might be skipped. The actual number of bytes
* skipped is equal to <code>n</code> or the result of a call to
* {@link #available() <code>available</code>}, whichever is smaller. If
* <code>n</code> is less than zero, no bytes are skipped.
*
* <p>
* The actual number of bytes skipped is returned.
*
* @param n
* the number of bytes to be skipped.
* @return the actual number of bytes skipped.
* @exception IOException
* if an I/O error occurs.
*/
public long skip(long n) throws IOException {
int available = ofinish - ostart;
if (n > available) {
n = available;
}
if (n < 0) {
return 0;
}
ostart += n;
return n;
}
/**
* Returns the number of bytes that can be read from this input stream without
* blocking. The <code>available</code> method of <code>InputStream</code>
* returns <code>0</code>. This method <B>should</B> be overridden by
* subclasses.
*
* @return the number of bytes that can be read from this input stream without
* blocking.
* @exception IOException
* if an I/O error occurs.
*/
public int available() throws IOException {
return (ofinish - ostart);
}
/**
* Closes this input stream and releases any system resources associated with
* the stream.
* <p>
* The <code>close</code> method of <code>SASLInputStream</code> calls the
* <code>close</code> method of its underlying input stream.
*
* @exception IOException
* if an I/O error occurs.
*/
public void close() throws IOException {
disposeSasl();
ostart = 0;
ofinish = 0;
inStream.close();
}
/**
* Tests if this input stream supports the <code>mark</code> and
* <code>reset</code> methods, which it does not.
*
* @return <code>false</code>, since this class does not support the
* <code>mark</code> and <code>reset</code> methods.
*/
public boolean markSupported() {
return false;
}
}

View File

@ -0,0 +1,181 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
/**
* A SaslOutputStream is composed of an OutputStream and a SaslServer (or
* SaslClient) so that write() methods first process the data before writing
* them out to the underlying OutputStream. The SaslServer (or SaslClient)
* object must be fully initialized before being used by a SaslOutputStream.
*/
public class SaslOutputStream extends OutputStream {
private final DataOutputStream outStream;
// processed data ready to be written out
private byte[] saslToken;
private final SaslClient saslClient;
private final SaslServer saslServer;
// buffer holding one byte of incoming data
private final byte[] ibuffer = new byte[1];
/**
* Constructs a SASLOutputStream from an OutputStream and a SaslServer <br>
* Note: if the specified OutputStream or SaslServer is null, a
* NullPointerException may be thrown later when they are used.
*
* @param outStream
* the OutputStream to be processed
* @param saslServer
* an initialized SaslServer object
*/
public SaslOutputStream(OutputStream outStream, SaslServer saslServer) {
this.outStream = new DataOutputStream(outStream);
this.saslServer = saslServer;
this.saslClient = null;
}
/**
* Constructs a SASLOutputStream from an OutputStream and a SaslClient <br>
* Note: if the specified OutputStream or SaslClient is null, a
* NullPointerException may be thrown later when they are used.
*
* @param outStream
* the OutputStream to be processed
* @param saslClient
* an initialized SaslClient object
*/
public SaslOutputStream(OutputStream outStream, SaslClient saslClient) {
this.outStream = new DataOutputStream(outStream);
this.saslServer = null;
this.saslClient = saslClient;
}
/**
* Disposes of any system resources or security-sensitive information Sasl
* might be using.
*
* @exception SaslException
* if a SASL error occurs.
*/
private void disposeSasl() throws SaslException {
if (saslClient != null) {
saslClient.dispose();
}
if (saslServer != null) {
saslServer.dispose();
}
}
/**
* Writes the specified byte to this output stream.
*
* @param b
* the <code>byte</code>.
* @exception IOException
* if an I/O error occurs.
*/
public void write(int b) throws IOException {
ibuffer[0] = (byte) b;
write(ibuffer, 0, 1);
}
/**
* Writes <code>b.length</code> bytes from the specified byte array to this
* output stream.
* <p>
* The <code>write</code> method of <code>SASLOutputStream</code> calls the
* <code>write</code> method of three arguments with the three arguments
* <code>b</code>, <code>0</code>, and <code>b.length</code>.
*
* @param b
* the data.
* @exception NullPointerException
* if <code>b</code> is null.
* @exception IOException
* if an I/O error occurs.
*/
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
/**
* Writes <code>len</code> bytes from the specified byte array starting at
* offset <code>off</code> to this output stream.
*
* @param inBuf
* the data.
* @param off
* the start offset in the data.
* @param len
* the number of bytes to write.
* @exception IOException
* if an I/O error occurs.
*/
public void write(byte[] inBuf, int off, int len) throws IOException {
try {
if (saslServer != null) { // using saslServer
saslToken = saslServer.wrap(inBuf, off, len);
} else { // using saslClient
saslToken = saslClient.wrap(inBuf, off, len);
}
} catch (SaslException se) {
try {
disposeSasl();
} catch (SaslException ignored) {
}
throw se;
}
if (saslToken != null) {
outStream.writeInt(saslToken.length);
outStream.write(saslToken, 0, saslToken.length);
saslToken = null;
}
}
/**
* Flushes this output stream
*
* @exception IOException
* if an I/O error occurs.
*/
public void flush() throws IOException {
outStream.flush();
}
/**
* Closes this output stream and releases any system resources associated with
* this stream.
*
* @exception IOException
* if an I/O error occurs.
*/
public void close() throws IOException {
disposeSasl();
outStream.close();
}
}

View File

@ -0,0 +1,249 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslClient;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
/**
* A utility class that encapsulates SASL logic for RPC client
*/
public class SaslRpcClient {
public static final Log LOG = LogFactory.getLog(SaslRpcClient.class);
private final SaslClient saslClient;
/**
* Create a SaslRpcClient for an authentication method
*
* @param method
* the requested authentication method
* @param token
* token to use if needed by the authentication method
*/
public SaslRpcClient(AuthMethod method,
Token<? extends TokenIdentifier> token, String serverPrincipal)
throws IOException {
switch (method) {
case DIGEST:
if (LOG.isDebugEnabled())
LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
+ " client to authenticate to service at " + token.getService());
saslClient = Sasl.createSaslClient(new String[] { AuthMethod.DIGEST
.getMechanismName() }, null, null, SaslRpcServer.SASL_DEFAULT_REALM,
SaslRpcServer.SASL_PROPS, new SaslClientCallbackHandler(token));
break;
case KERBEROS:
if (LOG.isDebugEnabled()) {
LOG
.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
+ " client. Server's Kerberos principal name is "
+ serverPrincipal);
}
if (serverPrincipal == null || serverPrincipal.length() == 0) {
throw new IOException(
"Failed to specify server's Kerberos principal name");
}
String names[] = SaslRpcServer.splitKerberosName(serverPrincipal);
if (names.length != 3) {
throw new IOException(
"Kerberos principal name does NOT have the expected hostname part: "
+ serverPrincipal);
}
saslClient = Sasl.createSaslClient(new String[] { AuthMethod.KERBEROS
.getMechanismName() }, null, names[0], names[1],
SaslRpcServer.SASL_PROPS, null);
break;
default:
throw new IOException("Unknown authentication method " + method);
}
if (saslClient == null)
throw new IOException("Unable to find SASL client implementation");
}
/**
* Do client side SASL authentication with server via the given InputStream
* and OutputStream
*
* @param inS
* InputStream to use
* @param outS
* OutputStream to use
* @throws IOException
*/
public void saslConnect(InputStream inS, OutputStream outS)
throws IOException {
DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(
outS));
try {
byte[] saslToken = new byte[0];
if (saslClient.hasInitialResponse())
saslToken = saslClient.evaluateChallenge(saslToken);
if (saslToken != null) {
outStream.writeInt(saslToken.length);
outStream.write(saslToken, 0, saslToken.length);
outStream.flush();
if (LOG.isDebugEnabled())
LOG.debug("Have sent token of size " + saslToken.length
+ " from initSASLContext.");
}
if (!saslClient.isComplete()) {
saslToken = new byte[inStream.readInt()];
if (LOG.isDebugEnabled())
LOG.debug("Will read input token of size " + saslToken.length
+ " for processing by initSASLContext");
inStream.readFully(saslToken);
}
while (!saslClient.isComplete()) {
saslToken = saslClient.evaluateChallenge(saslToken);
if (saslToken != null) {
if (LOG.isDebugEnabled())
LOG.debug("Will send token of size " + saslToken.length
+ " from initSASLContext.");
outStream.writeInt(saslToken.length);
outStream.write(saslToken, 0, saslToken.length);
outStream.flush();
}
if (!saslClient.isComplete()) {
saslToken = new byte[inStream.readInt()];
if (LOG.isDebugEnabled())
LOG.debug("Will read input token of size " + saslToken.length
+ " for processing by initSASLContext");
inStream.readFully(saslToken);
}
}
if (LOG.isDebugEnabled()) {
LOG.debug("SASL client context established. Negotiated QoP: "
+ saslClient.getNegotiatedProperty(Sasl.QOP));
}
} catch (IOException e) {
saslClient.dispose();
throw e;
}
}
/**
* Get a SASL wrapped InputStream. Can be called only after saslConnect() has
* been called.
*
* @param in
* the InputStream to wrap
* @return a SASL wrapped InputStream
* @throws IOException
*/
public InputStream getInputStream(InputStream in) throws IOException {
if (!saslClient.isComplete()) {
throw new IOException("Sasl authentication exchange hasn't completed yet");
}
return new SaslInputStream(in, saslClient);
}
/**
* Get a SASL wrapped OutputStream. Can be called only after saslConnect() has
* been called.
*
* @param out
* the OutputStream to wrap
* @return a SASL wrapped OutputStream
* @throws IOException
*/
public OutputStream getOutputStream(OutputStream out) throws IOException {
if (!saslClient.isComplete()) {
throw new IOException("Sasl authentication exchange hasn't completed yet");
}
return new SaslOutputStream(out, saslClient);
}
/** Release resources used by wrapped saslClient */
public void dispose() throws SaslException {
saslClient.dispose();
}
private static class SaslClientCallbackHandler implements CallbackHandler {
private final String userName;
private final char[] userPassword;
public SaslClientCallbackHandler(Token<? extends TokenIdentifier> token) {
this.userName = SaslRpcServer.encodeIdentifier(token.getIdentifier());
this.userPassword = SaslRpcServer.encodePassword(token.getPassword());
}
public void handle(Callback[] callbacks)
throws UnsupportedCallbackException {
NameCallback nc = null;
PasswordCallback pc = null;
RealmCallback rc = null;
for (Callback callback : callbacks) {
if (callback instanceof RealmChoiceCallback) {
continue;
} else if (callback instanceof NameCallback) {
nc = (NameCallback) callback;
} else if (callback instanceof PasswordCallback) {
pc = (PasswordCallback) callback;
} else if (callback instanceof RealmCallback) {
rc = (RealmCallback) callback;
} else {
throw new UnsupportedCallbackException(callback,
"Unrecognized SASL client callback");
}
}
if (nc != null) {
if (LOG.isDebugEnabled())
LOG.debug("SASL client callback: setting username: " + userName);
nc.setName(userName);
}
if (pc != null) {
if (LOG.isDebugEnabled())
LOG.debug("SASL client callback: setting userPassword");
pc.setPassword(userPassword);
}
if (rc != null) {
if (LOG.isDebugEnabled())
LOG.debug("SASL client callback: setting realm: "
+ rc.getDefaultText());
rc.setText(rc.getDefaultText());
}
}
}
}

View File

@ -0,0 +1,218 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import java.util.TreeMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
/**
* A utility class for dealing with SASL on RPC server
*/
public class SaslRpcServer {
public static final Log LOG = LogFactory.getLog(SaslRpcServer.class);
public static final String SASL_DEFAULT_REALM = "default";
public static final Map<String, String> SASL_PROPS =
new TreeMap<String, String>();
static {
// Request authentication plus integrity protection
SASL_PROPS.put(Sasl.QOP, "auth-int");
// Request mutual authentication
SASL_PROPS.put(Sasl.SERVER_AUTH, "true");
}
static String encodeIdentifier(byte[] identifier) {
return new String(Base64.encodeBase64(identifier));
}
static byte[] decodeIdentifier(String identifier) {
return Base64.decodeBase64(identifier.getBytes());
}
static char[] encodePassword(byte[] password) {
return new String(Base64.encodeBase64(password)).toCharArray();
}
/** Splitting fully qualified Kerberos name into parts */
public static String[] splitKerberosName(String fullName) {
return fullName.split("[/@]");
}
/** Authentication method */
public static enum AuthMethod {
SIMPLE((byte) 80, ""), // no authentication
KERBEROS((byte) 81, "GSSAPI"), // SASL Kerberos authentication
DIGEST((byte) 82, "DIGEST-MD5"); // SASL DIGEST-MD5 authentication
/** The code for this method. */
public final byte code;
public final String mechanismName;
private AuthMethod(byte code, String mechanismName) {
this.code = code;
this.mechanismName = mechanismName;
}
private static final int FIRST_CODE = values()[0].code;
/** Return the object represented by the code. */
private static AuthMethod valueOf(byte code) {
final int i = (code & 0xff) - FIRST_CODE;
return i < 0 || i >= values().length ? null : values()[i];
}
/** Return the SASL mechanism name */
public String getMechanismName() {
return mechanismName;
}
/** Read from in */
public static AuthMethod read(DataInput in) throws IOException {
return valueOf(in.readByte());
}
/** Write to out */
public void write(DataOutput out) throws IOException {
out.write(code);
}
};
/** CallbackHandler for SASL DIGEST-MD5 mechanism */
public static class SaslDigestCallbackHandler implements CallbackHandler {
private SecretManager<TokenIdentifier> secretManager;
public SaslDigestCallbackHandler(
SecretManager<TokenIdentifier> secretManager) {
this.secretManager = secretManager;
}
private TokenIdentifier getIdentifier(String id) throws IOException {
byte[] tokenId = decodeIdentifier(id);
TokenIdentifier tokenIdentifier = secretManager.createIdentifier();
tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(
tokenId)));
return tokenIdentifier;
}
private char[] getPassword(TokenIdentifier tokenid) throws IOException {
return encodePassword(secretManager.retrievePassword(tokenid));
}
/** {@inheritDoc} */
@Override
public void handle(Callback[] callbacks) throws IOException,
UnsupportedCallbackException {
NameCallback nc = null;
PasswordCallback pc = null;
AuthorizeCallback ac = null;
for (Callback callback : callbacks) {
if (callback instanceof AuthorizeCallback) {
ac = (AuthorizeCallback) callback;
} else if (callback instanceof NameCallback) {
nc = (NameCallback) callback;
} else if (callback instanceof PasswordCallback) {
pc = (PasswordCallback) callback;
} else if (callback instanceof RealmCallback) {
continue; // realm is ignored
} else {
throw new UnsupportedCallbackException(callback,
"Unrecognized SASL DIGEST-MD5 Callback");
}
}
if (pc != null) {
TokenIdentifier tokenIdentifier = getIdentifier(nc.getDefaultName());
char[] password = getPassword(tokenIdentifier);
if (LOG.isDebugEnabled()) {
LOG.debug("SASL server DIGEST-MD5 callback: setting password "
+ "for client: " + tokenIdentifier.getUsername());
}
pc.setPassword(password);
}
if (ac != null) {
String authid = ac.getAuthenticationID();
String authzid = ac.getAuthorizationID();
if (authid.equals(authzid)) {
ac.setAuthorized(true);
} else {
ac.setAuthorized(false);
}
if (ac.isAuthorized()) {
String username = getIdentifier(authzid).getUsername().toString();
if (LOG.isDebugEnabled())
LOG.debug("SASL server DIGEST-MD5 callback: setting "
+ "canonicalized client ID: " + username);
ac.setAuthorizedID(username);
}
}
}
}
/** CallbackHandler for SASL GSSAPI Kerberos mechanism */
public static class SaslGssCallbackHandler implements CallbackHandler {
/** {@inheritDoc} */
@Override
public void handle(Callback[] callbacks) throws IOException,
UnsupportedCallbackException {
AuthorizeCallback ac = null;
for (Callback callback : callbacks) {
if (callback instanceof AuthorizeCallback) {
ac = (AuthorizeCallback) callback;
} else {
throw new UnsupportedCallbackException(callback,
"Unrecognized SASL GSSAPI Callback");
}
}
if (ac != null) {
String authid = ac.getAuthenticationID();
String authzid = ac.getAuthorizationID();
if (authid.equals(authzid)) {
ac.setAuthorized(true);
} else {
ac.setAuthorized(false);
}
if (ac.isAuthorized()) {
if (LOG.isDebugEnabled())
LOG.debug("SASL server GSSAPI callback: setting "
+ "canonicalized client ID: " + authzid);
ac.setAuthorizedID(authzid);
}
}
}
}
}

View File

@ -61,6 +61,12 @@ public abstract class SecretManager<T extends TokenIdentifier> {
*/
public abstract byte[] retrievePassword(T identifier) throws InvalidToken;
/**
* Create an empty token identifier.
* @return the newly created empty token identifier
*/
public abstract T createIdentifier();
/**
* The name of the hashing algorithm.
*/

View File

@ -36,6 +36,12 @@ public abstract class TokenIdentifier implements Writable {
*/
public abstract Text getKind();
/**
* Get the username encoded in the token identifier
* @return the username
*/
public abstract Text getUsername();
/**
* Get the bytes for the token identifier
* @return the bytes of the identifier

View File

@ -0,0 +1,31 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security.token;
import java.lang.annotation.*;
/**
* Indicates Token related information to be used
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface TokenInfo {
/** The type of TokenSelector to be used */
Class<? extends TokenSelector<? extends TokenIdentifier>> value();
}

View File

@ -0,0 +1,34 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security.token;
import java.util.Collection;
import org.apache.hadoop.io.Text;
/**
* Select token of type T from tokens for use with named service
*
* @param <T>
* T extends TokenIdentifier
*/
public interface TokenSelector<T extends TokenIdentifier> {
Token<T> selectToken(Text service,
Collection<Token<? extends TokenIdentifier>> tokens);
}

View File

@ -63,4 +63,10 @@
<description>The name of the s3n file system for testing.</description>
</property>
<!-- Turn security off for tests by default -->
<property>
<name>hadoop.security.authentication</name>
<value>simple</value>
</property>
</configuration>

View File

@ -68,7 +68,7 @@ public class TestRPC extends TestCase {
int[] exchange(int[] values) throws IOException;
}
public class TestImpl implements TestProtocol {
public static class TestImpl implements TestProtocol {
int fastPingCounter = 0;
public long getProtocolVersion(String protocol, long clientVersion) {
@ -189,7 +189,7 @@ public class TestRPC extends TestCase {
System.out.println("Testing Slow RPC");
// create a server with two handlers
Server server = RPC.getServer(TestProtocol.class,
new TestImpl(), ADDRESS, 0, 2, false, conf);
new TestImpl(), ADDRESS, 0, 2, false, conf, null);
TestProtocol proxy = null;
try {
@ -339,7 +339,7 @@ public class TestRPC extends TestCase {
ServiceAuthorizationManager.refresh(conf, new TestPolicyProvider());
Server server = RPC.getServer(TestProtocol.class,
new TestImpl(), ADDRESS, 0, 5, true, conf);
new TestImpl(), ADDRESS, 0, 5, true, conf, null);
TestProtocol proxy = null;

View File

@ -0,0 +1,216 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import static org.apache.hadoop.fs.CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Collection;
import org.apache.commons.logging.*;
import org.apache.commons.logging.impl.Log4JLogger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.security.SaslInputStream;
import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.log4j.Level;
import org.junit.Test;
/** Unit tests for using Sasl over RPC. */
public class TestSaslRPC {
private static final String ADDRESS = "0.0.0.0";
public static final Log LOG =
LogFactory.getLog(TestSaslRPC.class);
static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
private static Configuration conf;
static {
conf = new Configuration();
conf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos");
UserGroupInformation.setConfiguration(conf);
}
static {
((Log4JLogger) Client.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) Server.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) SaslRpcClient.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) SaslRpcServer.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) SaslInputStream.LOG).getLogger().setLevel(Level.ALL);
}
public static class TestTokenIdentifier extends TokenIdentifier {
private Text tokenid;
final static Text KIND_NAME = new Text("test.token");
public TestTokenIdentifier() {
this.tokenid = new Text();
}
public TestTokenIdentifier(Text tokenid) {
this.tokenid = tokenid;
}
@Override
public Text getKind() {
return KIND_NAME;
}
@Override
public Text getUsername() {
return tokenid;
}
@Override
public void readFields(DataInput in) throws IOException {
tokenid.readFields(in);
}
@Override
public void write(DataOutput out) throws IOException {
tokenid.write(out);
}
}
public static class TestTokenSecretManager extends
SecretManager<TestTokenIdentifier> {
public byte[] createPassword(TestTokenIdentifier id) {
return id.getBytes();
}
public byte[] retrievePassword(TestTokenIdentifier id)
throws InvalidToken {
return id.getBytes();
}
public TestTokenIdentifier createIdentifier() {
return new TestTokenIdentifier();
}
}
public static class TestTokenSelector implements
TokenSelector<TestTokenIdentifier> {
@SuppressWarnings("unchecked")
@Override
public Token<TestTokenIdentifier> selectToken(Text service,
Collection<Token<? extends TokenIdentifier>> tokens) {
if (service == null) {
return null;
}
for (Token<? extends TokenIdentifier> token : tokens) {
if (TestTokenIdentifier.KIND_NAME.equals(token.getKind())
&& service.equals(token.getService())) {
return (Token<TestTokenIdentifier>) token;
}
}
return null;
}
}
@KerberosInfo(SERVER_PRINCIPAL_KEY)
@TokenInfo(TestTokenSelector.class)
public interface TestSaslProtocol extends TestRPC.TestProtocol {
}
public static class TestSaslImpl extends TestRPC.TestImpl implements
TestSaslProtocol {
}
@Test
public void testDigestRpc() throws Exception {
TestTokenSecretManager sm = new TestTokenSecretManager();
final Server server = RPC.getServer(TestSaslProtocol.class,
new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
server.start();
final UserGroupInformation current = UserGroupInformation.getCurrentUser();
final InetSocketAddress addr = NetUtils.getConnectAddress(server);
TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
.getUserName()));
Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
sm);
Text host = new Text(addr.getAddress().getHostAddress() + ":"
+ addr.getPort());
token.setService(host);
LOG.info("Service IP address for token is " + host);
current.addToken(token);
TestSaslProtocol proxy = null;
try {
proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, conf);
proxy.ping();
} finally {
server.stop();
if (proxy != null) {
RPC.stopProxy(proxy);
}
}
}
static void testKerberosRpc(String principal, String keytab) throws Exception {
final Configuration newConf = new Configuration(conf);
newConf.set(SERVER_PRINCIPAL_KEY, principal);
UserGroupInformation.loginUserFromKeytab(principal, keytab);
UserGroupInformation current = UserGroupInformation.getCurrentUser();
System.out.println("UGI: " + current);
Server server = RPC.getServer(TestSaslProtocol.class, new TestSaslImpl(),
ADDRESS, 0, 5, true, newConf, null);
TestSaslProtocol proxy = null;
server.start();
InetSocketAddress addr = NetUtils.getConnectAddress(server);
try {
proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
proxy.ping();
} finally {
server.stop();
if (proxy != null) {
RPC.stopProxy(proxy);
}
}
}
public static void main(String[] args) throws Exception {
System.out.println("Testing Kerberos authentication over RPC");
if (args.length != 2) {
System.err
.println("Usage: java <options> org.apache.hadoop.ipc.TestSaslRPC "
+ " <serverPrincipal> <keytabFile>");
System.exit(-1);
}
String principal = args[0];
String keytab = args[1];
testKerberosRpc(principal, keytab);
}
}