HADOOP-6419. Adds SASL based authentication to RPC. Contributed by Kan Zhang.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@905860 13f79535-47bb-0310-9956-ffa450edef68
Author: Devaraj Das
Date:   2010-02-03 01:30:25 +00:00
Parent: fe0ddc03e1
Commit: 940389afce
19 changed files with 1713 additions and 80 deletions

CHANGES.txt

@ -48,6 +48,9 @@ Trunk (unreleased changes)
upon login. The tokens are read from a file specified in the
environment variable. (ddas)
HADOOP-6419. Adds SASL based authentication to RPC.
(Kan Zhang via ddas)
IMPROVEMENTS
HADOOP-6283. Improve the exception messages thrown by

org/apache/hadoop/ipc/AvroRpcEngine.java

@ -33,6 +33,8 @@ import org.apache.commons.logging.*;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.net.NetUtils;
import org.apache.avro.*;
@ -192,10 +194,13 @@ class AvroRpcEngine implements RpcEngine {
* port and address. */
public RPC.Server getServer(Class iface, Object impl, String bindAddress,
int port, int numHandlers, boolean verbose,
- Configuration conf) throws IOException {
+ Configuration conf,
+ SecretManager<? extends TokenIdentifier> secretManager
+ ) throws IOException {
return ENGINE.getServer(TunnelProtocol.class,
new TunnelResponder(iface, impl),
- bindAddress, port, numHandlers, verbose, conf);
+ bindAddress, port, numHandlers, verbose, conf,
+ secretManager);
}
}

org/apache/hadoop/ipc/Client.java

@ -31,7 +31,9 @@ import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FilterInputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.PrivilegedExceptionAction;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map.Entry;
@ -44,11 +46,19 @@ import org.apache.commons.logging.*;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.util.ReflectionUtils;
/** A client for an IPC service. IPC calls take a single {@link Writable} as a
@ -196,8 +206,13 @@ public class Client {
* socket: responses may be delivered out of order. */
private class Connection extends Thread {
private InetSocketAddress server; // server ip:port
private String serverPrincipal; // server's krb5 principal name
private ConnectionHeader header; // connection header
- private ConnectionId remoteId; // connection id
+ private final ConnectionId remoteId; // connection id
private final AuthMethod authMethod; // authentication method
private final boolean useSasl;
private Token<? extends TokenIdentifier> token;
private SaslRpcClient saslRpcClient;
private Socket socket = null; // connected socket
private DataInputStream in;
@ -221,6 +236,42 @@ public class Client {
Class<?> protocol = remoteId.getProtocol();
header =
new ConnectionHeader(protocol == null ? null : protocol.getName(), ticket);
this.useSasl = UserGroupInformation.isSecurityEnabled();
if (useSasl && protocol != null) {
TokenInfo tokenInfo = protocol.getAnnotation(TokenInfo.class);
if (tokenInfo != null) {
TokenSelector<? extends TokenIdentifier> tokenSelector = null;
try {
tokenSelector = tokenInfo.value().newInstance();
} catch (InstantiationException e) {
throw new IOException(e.toString());
} catch (IllegalAccessException e) {
throw new IOException(e.toString());
}
InetSocketAddress addr = remoteId.getAddress();
token = tokenSelector.selectToken(new Text(addr.getAddress()
.getHostAddress() + ":" + addr.getPort()),
ticket.getTokens());
}
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
if (krbInfo != null) {
String serverKey = krbInfo.value();
if (serverKey != null) {
serverPrincipal = conf.get(serverKey);
}
}
}
if (!useSasl) {
authMethod = AuthMethod.SIMPLE;
} else if (token != null) {
authMethod = AuthMethod.DIGEST;
} else {
authMethod = AuthMethod.KERBEROS;
}
if (LOG.isDebugEnabled())
LOG.debug("Use " + authMethod + " authentication for protocol "
+ protocol.getSimpleName());
this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " + this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " +
remoteId.getAddress().toString() + remoteId.getAddress().toString() +
@ -302,11 +353,20 @@ public class Client {
}
}
private synchronized void disposeSasl() {
if (saslRpcClient != null) {
try {
saslRpcClient.dispose();
} catch (IOException ignored) {
}
}
}
/** Connect to the server and set up the I/O streams. It then sends
* a header to the server and starts
* the connection thread that waits for responses.
*/
- private synchronized void setupIOstreams() {
+ private synchronized void setupIOstreams() throws InterruptedException {
if (socket != null || shouldCloseConnection.get()) {
return;
}
@ -334,15 +394,33 @@ public class Client {
handleConnectionFailure(ioFailures++, maxRetries, ie);
}
}
InputStream inStream = NetUtils.getInputStream(socket);
OutputStream outStream = NetUtils.getOutputStream(socket);
writeRpcHeader(outStream);
if (useSasl) {
final InputStream in2 = inStream;
final OutputStream out2 = outStream;
remoteId.getTicket().doAs(new PrivilegedExceptionAction<Object>() {
@Override
public Object run() throws IOException {
saslRpcClient = new SaslRpcClient(authMethod, token,
serverPrincipal);
saslRpcClient.saslConnect(in2, out2);
return null;
}
});
inStream = saslRpcClient.getInputStream(inStream);
outStream = saslRpcClient.getOutputStream(outStream);
}
if (doPing) {
this.in = new DataInputStream(new BufferedInputStream
- (new PingInputStream(NetUtils.getInputStream(socket))));
+ (new PingInputStream(inStream)));
} else {
this.in = new DataInputStream(new BufferedInputStream
- (NetUtils.getInputStream(socket)));
+ (inStream));
}
this.out = new DataOutputStream
- (new BufferedOutputStream(NetUtils.getOutputStream(socket)));
+ (new BufferedOutputStream(outStream));
writeHeader();
// update last activity time
@ -396,14 +474,20 @@ public class Client {
". Already tried " + curRetries + " time(s)."); ". Already tried " + curRetries + " time(s).");
} }
/* Write the header for each connection /* Write the RPC header */
private void writeRpcHeader(OutputStream outStream) throws IOException {
DataOutputStream out = new DataOutputStream(new BufferedOutputStream(outStream));
// Write out the header, version and authentication method
out.write(Server.HEADER.array());
out.write(Server.CURRENT_VERSION);
authMethod.write(out);
out.flush();
}
/* Write the protocol header for each connection
* Out is not synchronized because only the first thread does this.
*/
private void writeHeader() throws IOException {
- // Write out the header and version
- out.write(Server.HEADER.array());
- out.write(Server.CURRENT_VERSION);
// Write out the ConnectionHeader
DataOutputBuffer buf = new DataOutputBuffer();
header.write(buf);
@ -575,6 +659,7 @@ public class Client {
// close the streams and therefore the socket
IOUtils.closeStream(out);
IOUtils.closeStream(in);
disposeSasl();
// clean up all calls
if (closeException == null) {
@ -815,7 +900,7 @@ public class Client {
*/
@Deprecated
public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
- throws IOException {
+ throws IOException, InterruptedException {
return call(params, addresses, null, null);
}
@ -825,7 +910,7 @@ public class Client {
* contains nulls for calls that timed out or errored. */
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
Class<?> protocol, UserGroupInformation ticket)
- throws IOException {
+ throws IOException, InterruptedException {
if (addresses.length == 0) return new Writable[0];
ParallelResults results = new ParallelResults(params.length);
@ -859,7 +944,7 @@ public class Client {
Class<?> protocol,
UserGroupInformation ticket,
Call call)
- throws IOException {
+ throws IOException, InterruptedException {
if (!running.get()) {
// the client is stopped
throw new IOException("The client is stopped");
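The Connection constructor above picks one authentication method per connection: SIMPLE when security is off, DIGEST when the protocol's TokenInfo selector finds a token for the target address in the caller's UGI, and KERBEROS otherwise. A minimal sketch of a protocol interface wired for that selection follows; the interface name, selector class, and token type are hypothetical, only the @TokenInfo annotation and the TokenSelector contract come from this change set.

    // Hypothetical protocol: MyTokenSelector implements TokenSelector<MyTokenIdentifier>.
    @TokenInfo(MyTokenSelector.class)
    public interface MyProtocol {
      long getTransactionId() throws IOException;
    }

With security enabled, a connection for MyProtocol asks MyTokenSelector for a token keyed by "host:port"; if one is found the client authenticates with DIGEST, otherwise it falls back to Kerberos against the principal named by the protocol's KerberosInfo configuration key.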

org/apache/hadoop/ipc/RPC.java

@ -37,6 +37,8 @@ import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.AuthorizationException;
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.conf.*;
import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
import org.apache.hadoop.util.ReflectionUtils;
@ -254,7 +256,7 @@ public class RPC {
@Deprecated
public static Object[] call(Method method, Object[][] params,
InetSocketAddress[] addrs, Configuration conf)
- throws IOException {
+ throws IOException, InterruptedException {
return call(method, params, addrs, null, conf);
}
@ -262,7 +264,7 @@ public class RPC {
public static Object[] call(Method method, Object[][] params,
InetSocketAddress[] addrs,
UserGroupInformation ticket, Configuration conf)
- throws IOException {
+ throws IOException, InterruptedException {
return getProtocolEngine(method.getDeclaringClass(), conf)
.call(method, params, addrs, ticket, conf);
@ -288,7 +290,7 @@ public class RPC {
final boolean verbose, Configuration conf)
throws IOException {
return getServer(instance.getClass(), // use impl class for protocol
- instance, bindAddress, port, numHandlers, false, conf);
+ instance, bindAddress, port, numHandlers, false, conf, null);
}
/** Construct a server for a protocol implementation instance. */
@ -296,19 +298,34 @@ public class RPC {
Object instance, String bindAddress,
int port, Configuration conf)
throws IOException {
- return getServer(protocol, instance, bindAddress, port, 1, false, conf);
+ return getServer(protocol, instance, bindAddress, port, 1, false, conf, null);
}
- /** Construct a server for a protocol implementation instance. */
+ /** Construct a server for a protocol implementation instance.
* @deprecated secretManager should be passed.
*/
@Deprecated
public static Server getServer(Class protocol,
Object instance, String bindAddress, int port,
int numHandlers,
boolean verbose, Configuration conf)
throws IOException {
return getServer(protocol, instance, bindAddress, port, numHandlers, verbose,
conf, null);
}
/** Construct a server for a protocol implementation instance. */
public static Server getServer(Class<?> protocol,
Object instance, String bindAddress, int port,
int numHandlers,
boolean verbose, Configuration conf,
SecretManager<? extends TokenIdentifier> secretManager)
throws IOException {
return getProtocolEngine(protocol, conf)
.getServer(protocol, instance, bindAddress, port, numHandlers, verbose,
- conf);
+ conf, secretManager);
}
/** An RPC Server. */
@ -316,8 +333,9 @@ public class RPC {
protected Server(String bindAddress, int port,
Class<? extends Writable> paramClass, int handlerCount,
- Configuration conf, String serverName) throws IOException {
- super(bindAddress, port, paramClass, handlerCount, conf, serverName);
+ Configuration conf, String serverName,
+ SecretManager<? extends TokenIdentifier> secretManager) throws IOException {
+ super(bindAddress, port, paramClass, handlerCount, conf, serverName, secretManager);
}
}
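A call-site sketch of the getServer() change above. The protocol class, implementation, and secret manager are placeholders; only the method signatures come from this patch, and passing null (as the deprecated overloads now do internally) simply leaves token authentication unavailable on that server.

    // Deprecated 7-argument form, still compiles:
    RPC.Server legacy = RPC.getServer(MyProtocol.class, impl, "0.0.0.0", 9000, 5, false, conf);

    // New form: supply a SecretManager so DIGEST (token) authentication can be validated server side.
    SecretManager<? extends TokenIdentifier> secrets = createSecretManager(); // hypothetical factory
    RPC.Server server = RPC.getServer(MyProtocol.class, impl, "0.0.0.0", 9000, 5, false, conf, secrets);
    server.start();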

org/apache/hadoop/ipc/RpcEngine.java

@ -24,6 +24,8 @@ import java.net.InetSocketAddress;
import javax.net.SocketFactory;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.conf.Configuration;
/** An RPC implementation. */
@ -41,11 +43,13 @@ interface RpcEngine {
/** Expert: Make multiple, parallel calls to a set of servers. */
Object[] call(Method method, Object[][] params, InetSocketAddress[] addrs,
UserGroupInformation ticket, Configuration conf)
- throws IOException;
+ throws IOException, InterruptedException;
/** Construct a server for a protocol implementation instance. */
RPC.Server getServer(Class protocol, Object instance, String bindAddress,
int port, int numHandlers, boolean verbose,
- Configuration conf) throws IOException;
+ Configuration conf,
+ SecretManager<? extends TokenIdentifier> secretManager
+ ) throws IOException;
}

org/apache/hadoop/ipc/Server.java

@ -32,6 +32,7 @@ import java.net.SocketException;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
@ -39,7 +40,6 @@ import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
@ -52,15 +52,26 @@ import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.ipc.metrics.RpcMetrics;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SaslRpcServer.SaslDigestCallbackHandler;
import org.apache.hadoop.security.SaslRpcServer.SaslGssCallbackHandler;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.AuthorizationException;
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.StringUtils;
@ -80,7 +91,8 @@ public abstract class Server {
// 1 : Introduce ping and server does not throw away RPCs
// 3 : Introduce the protocol into the RPC connection header
- public static final byte CURRENT_VERSION = 3;
+ // 4 : Introduced SASL security layer
+ public static final byte CURRENT_VERSION = 4;
/**
* How many calls/handler are allowed in the queue.
@ -158,6 +170,7 @@ public abstract class Server {
protected RpcMetrics rpcMetrics;
private Configuration conf;
private SecretManager<TokenIdentifier> secretManager;
private int maxQueueSize;
private int socketSendBufferSize;
@ -431,7 +444,7 @@ public abstract class Server {
if (count < 0) {
if (LOG.isDebugEnabled())
LOG.debug(getName() + ": disconnecting client " +
- c.getHostAddress() + ". Number of active connections: "+
+ c + ". Number of active connections: "+
numConnections);
closeConnection(c);
c = null;
@ -703,8 +716,7 @@ public abstract class Server {
/** Reads calls from a connection and queues them for handling. */
private class Connection {
- private boolean versionRead = false; //if initial signature and
- //version are read
+ private boolean rpcHeaderRead = false; // if initial rpc header is read
private boolean headerRead = false; //if the connection header that
//follows version is read.
@ -723,6 +735,13 @@ public abstract class Server {
ConnectionHeader header = new ConnectionHeader();
Class<?> protocol;
boolean useSasl;
SaslServer saslServer;
private AuthMethod authMethod;
private boolean saslContextEstablished;
private ByteBuffer rpcHeaderBuffer;
private ByteBuffer unwrappedData;
private ByteBuffer unwrappedDataLengthBuffer;
UserGroupInformation user = null;
@ -731,6 +750,10 @@ public abstract class Server {
private final Call authFailedCall =
new Call(AUTHROIZATION_FAILED_CALLID, null, null);
private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
// Fake 'call' for SASL context setup
private static final int SASL_CALLID = -33;
private final Call saslCall = new Call(SASL_CALLID, null, null);
private final ByteArrayOutputStream saslResponse = new ByteArrayOutputStream();
public Connection(SelectionKey key, SocketChannel channel,
long lastContact) {
@ -738,6 +761,8 @@ public abstract class Server {
this.lastContact = lastContact;
this.data = null;
this.dataLengthBuffer = ByteBuffer.allocate(4);
this.unwrappedData = null;
this.unwrappedDataLengthBuffer = ByteBuffer.allocate(4);
this.socket = channel.socket();
InetAddress addr = socket.getInetAddress();
if (addr == null) {
@ -795,6 +820,92 @@ public abstract class Server {
return false;
}
private void saslReadAndProcess(byte[] saslToken) throws IOException,
InterruptedException {
if (!saslContextEstablished) {
if (saslServer == null) {
switch (authMethod) {
case DIGEST:
saslServer = Sasl.createSaslServer(AuthMethod.DIGEST
.getMechanismName(), null, SaslRpcServer.SASL_DEFAULT_REALM,
SaslRpcServer.SASL_PROPS, new SaslDigestCallbackHandler(
secretManager));
break;
default:
UserGroupInformation current = UserGroupInformation
.getCurrentUser();
String fullName = current.getUserName();
if (LOG.isDebugEnabled())
LOG.debug("Kerberos principal name is " + fullName);
final String names[] = SaslRpcServer.splitKerberosName(fullName);
if (names.length != 3) {
throw new IOException(
"Kerberos principal name does NOT have the expected "
+ "hostname part: " + fullName);
}
current.doAs(new PrivilegedExceptionAction<Object>() {
@Override
public Object run() throws IOException {
saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS
.getMechanismName(), names[0], names[1],
SaslRpcServer.SASL_PROPS, new SaslGssCallbackHandler());
return null;
}
});
}
if (saslServer == null)
throw new IOException(
"Unable to find SASL server implementation for "
+ authMethod.getMechanismName());
if (LOG.isDebugEnabled())
LOG.debug("Created SASL server with mechanism = "
+ authMethod.getMechanismName());
}
if (LOG.isDebugEnabled())
LOG.debug("Have read input token of size " + saslToken.length
+ " for processing by saslServer.evaluateResponse()");
byte[] replyToken = saslServer.evaluateResponse(saslToken);
if (replyToken != null) {
if (LOG.isDebugEnabled())
LOG.debug("Will send token of size " + replyToken.length
+ " from saslServer.");
saslCall.connection = this;
saslResponse.reset();
DataOutputStream out = new DataOutputStream(saslResponse);
out.writeInt(replyToken.length);
out.write(replyToken, 0, replyToken.length);
saslCall.setResponse(ByteBuffer.wrap(saslResponse.toByteArray()));
responder.doRespond(saslCall);
}
if (saslServer.isComplete()) {
if (LOG.isDebugEnabled()) {
LOG.debug("SASL server context established. Negotiated QoP is "
+ saslServer.getNegotiatedProperty(Sasl.QOP));
}
user = UserGroupInformation.createRemoteUser(saslServer
.getAuthorizationID());
LOG.info("SASL server successfully authenticated client: " + user);
saslContextEstablished = true;
}
} else {
if (LOG.isDebugEnabled())
LOG.debug("Have read input token of size " + saslToken.length
+ " for processing by saslServer.unwrap()");
byte[] plaintextData = saslServer
.unwrap(saslToken, 0, saslToken.length);
processUnwrappedData(plaintextData);
}
}
private void disposeSasl() {
if (saslServer != null) {
try {
saslServer.dispose();
} catch (SaslException ignored) {
}
}
}
public int readAndProcess() throws IOException, InterruptedException {
while (true) {
/* Read at most one RPC. If the header is not read completely yet
@ -807,14 +918,33 @@ public abstract class Server {
return count;
}
- if (!versionRead) {
+ if (!rpcHeaderRead) {
//Every connection is expected to send the header.
- ByteBuffer versionBuffer = ByteBuffer.allocate(1);
- count = channelRead(channel, versionBuffer);
- if (count <= 0) {
+ if (rpcHeaderBuffer == null) {
+ rpcHeaderBuffer = ByteBuffer.allocate(2);
+ }
+ count = channelRead(channel, rpcHeaderBuffer);
+ if (count < 0 || rpcHeaderBuffer.remaining() > 0) {
return count;
}
- int version = versionBuffer.get(0);
+ int version = rpcHeaderBuffer.get(0);
byte[] method = new byte[] {rpcHeaderBuffer.get(1)};
authMethod = AuthMethod.read(new DataInputStream(
new ByteArrayInputStream(method)));
if (authMethod == null) {
throw new IOException("Unable to read authentication method");
}
if (UserGroupInformation.isSecurityEnabled()
&& authMethod == AuthMethod.SIMPLE) {
throw new IOException("Authentication is required");
}
if (!UserGroupInformation.isSecurityEnabled()
&& authMethod != AuthMethod.SIMPLE) {
throw new IOException("Authentication is not supported");
}
if (authMethod != AuthMethod.SIMPLE) {
useSasl = true;
}
dataLengthBuffer.flip();
if (!HEADER.equals(dataLengthBuffer) || version != CURRENT_VERSION) {
@ -826,7 +956,8 @@ public abstract class Server {
return -1;
}
dataLengthBuffer.clear();
- versionRead = true;
+ rpcHeaderBuffer = null;
+ rpcHeaderRead = true;
continue;
}
@ -834,12 +965,11 @@ public abstract class Server {
dataLengthBuffer.flip();
dataLength = dataLengthBuffer.getInt();
- if (dataLength == Client.PING_CALL_ID) {
+ if (!useSasl && dataLength == Client.PING_CALL_ID) {
dataLengthBuffer.clear();
return 0; //ping message
}
data = ByteBuffer.allocate(dataLength);
- incRpcCount(); // Increment the rpc count
}
count = channelRead(channel, data);
@ -847,33 +977,14 @@ public abstract class Server {
if (data.remaining() == 0) {
dataLengthBuffer.clear();
data.flip();
- if (headerRead) {
- processData();
- data = null;
- return count;
+ boolean isHeaderRead = headerRead;
+ if (useSasl) {
+ saslReadAndProcess(data.array());
} else {
- processHeader();
- headerRead = true;
+ processOneRpc(data.array());
+ }
data = null;
if (!isHeaderRead) {
// Authorize the connection
try {
authorize(user, header);
if (LOG.isDebugEnabled()) {
LOG.debug("Successfully authorized " + header);
}
} catch (AuthorizationException ae) {
authFailedCall.connection = this;
setupResponse(authFailedResponse, authFailedCall,
Status.FATAL, null,
ae.getClass().getName(), ae.getMessage());
responder.doRespond(authFailedCall);
// Close this connection
return -1;
}
continue;
}
}
@ -882,9 +993,9 @@ public abstract class Server {
}
/// Reads the connection header following version
- private void processHeader() throws IOException {
+ private void processHeader(byte[] buf) throws IOException {
DataInputStream in =
- new DataInputStream(new ByteArrayInputStream(data.array()));
+ new DataInputStream(new ByteArrayInputStream(buf));
header.readFields(in);
try {
String protocolClassName = header.getProtocol();
@ -895,12 +1006,73 @@ public abstract class Server {
throw new IOException("Unknown protocol: " + header.getProtocol()); throw new IOException("Unknown protocol: " + header.getProtocol());
} }
user = header.getUgi(); UserGroupInformation protocolUser = header.getUgi();
if (!useSasl) {
user = protocolUser;
} else if (protocolUser != null && !protocolUser.equals(user)) {
throw new AccessControlException("Authenticated user (" + user
+ ") doesn't match what the client claims to be (" + protocolUser
+ ")");
}
}
- private void processData() throws IOException, InterruptedException {
+ private void processUnwrappedData(byte[] inBuf) throws IOException,
InterruptedException {
ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(
inBuf));
// Read all RPCs contained in the inBuf, even partial ones
while (true) {
int count = -1;
if (unwrappedDataLengthBuffer.remaining() > 0) {
count = channelRead(ch, unwrappedDataLengthBuffer);
if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0)
return;
}
if (unwrappedData == null) {
unwrappedDataLengthBuffer.flip();
int unwrappedDataLength = unwrappedDataLengthBuffer.getInt();
if (unwrappedDataLength == Client.PING_CALL_ID) {
if (LOG.isDebugEnabled())
LOG.debug("Received ping message");
unwrappedDataLengthBuffer.clear();
continue; // ping message
}
unwrappedData = ByteBuffer.allocate(unwrappedDataLength);
}
count = channelRead(ch, unwrappedData);
if (count <= 0 || unwrappedData.remaining() > 0)
return;
if (unwrappedData.remaining() == 0) {
unwrappedDataLengthBuffer.clear();
unwrappedData.flip();
processOneRpc(unwrappedData.array());
unwrappedData = null;
}
}
}
private void processOneRpc(byte[] buf) throws IOException,
InterruptedException {
if (headerRead) {
processData(buf);
} else {
processHeader(buf);
headerRead = true;
if (!authorizeConnection()) {
throw new AccessControlException("Connection from " + this
+ " for protocol " + header.getProtocol()
+ " is unauthorized for user " + user);
}
}
}
private void processData(byte[] buf) throws IOException, InterruptedException {
DataInputStream dis =
- new DataInputStream(new ByteArrayInputStream(data.array()));
+ new DataInputStream(new ByteArrayInputStream(buf));
int id = dis.readInt(); // try to read an id
if (LOG.isDebugEnabled())
@ -911,9 +1083,27 @@ public abstract class Server {
Call call = new Call(id, param, this);
callQueue.put(call); // queue the call; maybe blocked here
incRpcCount(); // Increment the rpc count
}
private boolean authorizeConnection() throws IOException {
try {
authorize(user, header);
if (LOG.isDebugEnabled()) {
LOG.debug("Successfully authorized " + header);
}
} catch (AuthorizationException ae) {
authFailedCall.connection = this;
setupResponse(authFailedResponse, authFailedCall, Status.FATAL, null,
ae.getClass().getName(), ae.getMessage());
responder.doRespond(authFailedCall);
return false;
}
return true;
}
private synchronized void close() throws IOException {
disposeSasl();
data = null;
dataLengthBuffer = null;
if (!channel.isOpen())
@ -1011,16 +1201,17 @@ public abstract class Server {
Configuration conf)
throws IOException
{
- this(bindAddress, port, paramClass, handlerCount, conf, Integer.toString(port));
+ this(bindAddress, port, paramClass, handlerCount, conf, Integer.toString(port), null);
}
/** Constructs a server listening on the named port and address. Parameters passed must
* be of the named class. The <code>handlerCount</handlerCount> determines
* the number of handler threads that will be used to process calls.
*
*/
@SuppressWarnings("unchecked")
protected Server(String bindAddress, int port,
Class<? extends Writable> paramClass, int handlerCount,
- Configuration conf, String serverName)
+ Configuration conf, String serverName, SecretManager<? extends TokenIdentifier> secretManager)
throws IOException {
this.bindAddress = bindAddress;
this.conf = conf;
@ -1033,6 +1224,7 @@ public abstract class Server {
this.maxIdleTime = 2*conf.getInt("ipc.client.connection.maxidletime", 1000); this.maxIdleTime = 2*conf.getInt("ipc.client.connection.maxidletime", 1000);
this.maxConnectionsToNuke = conf.getInt("ipc.client.kill.max", 10); this.maxConnectionsToNuke = conf.getInt("ipc.client.kill.max", 10);
this.thresholdIdleConnections = conf.getInt("ipc.client.idlethreshold", 4000); this.thresholdIdleConnections = conf.getInt("ipc.client.idlethreshold", 4000);
this.secretManager = (SecretManager<TokenIdentifier>) secretManager;
this.authorize = this.authorize =
conf.getBoolean(ServiceAuthorizationManager.SERVICE_AUTHORIZATION_CONFIG, conf.getBoolean(ServiceAuthorizationManager.SERVICE_AUTHORIZATION_CONFIG,
false); false);
@ -1086,9 +1278,29 @@ public abstract class Server {
WritableUtils.writeString(out, errorClass);
WritableUtils.writeString(out, error);
}
wrapWithSasl(response, call);
call.setResponse(ByteBuffer.wrap(response.toByteArray()));
}
private void wrapWithSasl(ByteArrayOutputStream response, Call call)
throws IOException {
if (call.connection.useSasl) {
byte[] token = response.toByteArray();
// synchronization may be needed since there can be multiple Handler
// threads using saslServer to wrap responses.
synchronized (call.connection.saslServer) {
token = call.connection.saslServer.wrap(token, 0, token.length);
}
if (LOG.isDebugEnabled())
LOG.debug("Adding saslServer wrapped token of size " + token.length
+ " as call response.");
response.reset();
DataOutputStream saslOut = new DataOutputStream(response);
saslOut.writeInt(token.length);
saslOut.write(token, 0, token.length);
}
}
Configuration getConf() {
return conf;
}
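Together with Client.writeRpcHeader() above, the connection preamble that readAndProcess() now parses is: the 4-byte header magic, one version byte (now 4), and one authentication-method byte; for KERBEROS or DIGEST the peers then exchange SASL tokens, each framed as a 4-byte length plus the token bytes, before the ConnectionHeader and (possibly SASL-wrapped) RPC payloads flow. A rough client-side sketch, assuming Server.HEADER is the usual "hrpc" magic used by this class; everything else is illustrative, not part of the patch:

    // Illustrative preamble writer mirroring Client.writeRpcHeader() in this patch.
    // Types come from org.apache.hadoop.ipc.Server and SaslRpcServer.AuthMethod.
    static void writePreamble(OutputStream rawOut, AuthMethod authMethod) throws IOException {
      DataOutputStream out = new DataOutputStream(new BufferedOutputStream(rawOut));
      out.write(Server.HEADER.array());    // 4-byte magic
      out.write(Server.CURRENT_VERSION);   // version byte, bumped to 4 by this change
      authMethod.write(out);               // one byte: SIMPLE, KERBEROS or DIGEST
      out.flush();
    }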

org/apache/hadoop/ipc/WritableRpcEngine.java

@ -36,6 +36,8 @@ import org.apache.commons.logging.*;
import org.apache.hadoop.io.*;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.conf.*;
import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
@ -246,7 +248,7 @@ class WritableRpcEngine implements RpcEngine {
public Object[] call(Method method, Object[][] params,
InetSocketAddress[] addrs,
UserGroupInformation ticket, Configuration conf)
- throws IOException {
+ throws IOException, InterruptedException {
Invocation[] invocations = new Invocation[params.length];
for (int i = 0; i < params.length; i++)
@ -276,9 +278,11 @@ class WritableRpcEngine implements RpcEngine {
* port and address. */
public Server getServer(Class protocol,
Object instance, String bindAddress, int port,
- int numHandlers, boolean verbose, Configuration conf)
+ int numHandlers, boolean verbose, Configuration conf,
+ SecretManager<? extends TokenIdentifier> secretManager)
throws IOException {
- return new Server(instance, conf, bindAddress, port, numHandlers, verbose);
+ return new Server(instance, conf, bindAddress, port, numHandlers,
+ verbose, secretManager);
}
/** An RPC Server. */
@ -294,7 +298,7 @@ class WritableRpcEngine implements RpcEngine {
*/
public Server(Object instance, Configuration conf, String bindAddress, int port)
throws IOException {
- this(instance, conf, bindAddress, port, 1, false);
+ this(instance, conf, bindAddress, port, 1, false, null);
}
private static String classNameBase(String className) {
@ -314,8 +318,11 @@ class WritableRpcEngine implements RpcEngine {
* @param verbose whether each call should be logged
*/
public Server(Object instance, Configuration conf, String bindAddress, int port,
- int numHandlers, boolean verbose) throws IOException {
- super(bindAddress, port, Invocation.class, numHandlers, conf, classNameBase(instance.getClass().getName()));
+ int numHandlers, boolean verbose,
+ SecretManager<? extends TokenIdentifier> secretManager)
+ throws IOException {
+ super(bindAddress, port, Invocation.class, numHandlers, conf,
+ classNameBase(instance.getClass().getName()), secretManager);
this.instance = instance;
this.verbose = verbose;
}

org/apache/hadoop/security/KerberosInfo.java

@ -0,0 +1,31 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.lang.annotation.*;
/**
* Indicates Kerberos related information to be used
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface KerberosInfo {
/** Key for getting server's Kerberos principal name from Configuration */
String value();
}
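A usage sketch for the annotation: the protocol interface and the configuration key below are hypothetical; only @KerberosInfo itself comes from this patch.

    // The IPC client resolves the server principal with conf.get("my.service.kerberos.principal")
    // before starting the Kerberos SASL handshake (see Client.Connection in this patch).
    @KerberosInfo("my.service.kerberos.principal")
    public interface MyAdminProtocol {
      void refreshConfiguration() throws IOException;
    }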

org/apache/hadoop/security/SaslInputStream.java

@ -0,0 +1,321 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.InputStream;
import java.io.IOException;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* A SaslInputStream is composed of an InputStream and a SaslServer (or
* SaslClient) so that read() methods return data that are read in from the
* underlying InputStream but have been additionally processed by the SaslServer
* (or SaslClient) object. The SaslServer (or SaslClient) object must be fully
* initialized before being used by a SaslInputStream.
*/
public class SaslInputStream extends InputStream {
public static final Log LOG = LogFactory.getLog(SaslInputStream.class);
private final DataInputStream inStream;
/*
* data read from the underlying input stream before being processed by SASL
*/
private byte[] saslToken;
private final SaslClient saslClient;
private final SaslServer saslServer;
private byte[] lengthBuf = new byte[4];
/*
* buffer holding data that have been processed by SASL, but have not been
* read out
*/
private byte[] obuffer;
// position of the next "new" byte
private int ostart = 0;
// position of the last "new" byte
private int ofinish = 0;
private static int unsignedBytesToInt(byte[] buf) {
if (buf.length != 4) {
throw new IllegalArgumentException(
"Cannot handle byte array other than 4 bytes");
}
int result = 0;
for (int i = 0; i < 4; i++) {
result <<= 8;
result |= ((int) buf[i] & 0xff);
}
return result;
}
/**
* Read more data and get them processed <br>
* Entry condition: ostart = ofinish <br>
* Exit condition: ostart <= ofinish <br>
*
* return (ofinish-ostart) (we have this many bytes for you), 0 (no data now,
* but could have more later), or -1 (absolutely no more data)
*/
private int readMoreData() throws IOException {
try {
inStream.readFully(lengthBuf);
int length = unsignedBytesToInt(lengthBuf);
if (LOG.isDebugEnabled())
LOG.debug("Actual length is " + length);
saslToken = new byte[length];
inStream.readFully(saslToken);
} catch (EOFException e) {
return -1;
}
try {
if (saslServer != null) { // using saslServer
obuffer = saslServer.unwrap(saslToken, 0, saslToken.length);
} else { // using saslClient
obuffer = saslClient.unwrap(saslToken, 0, saslToken.length);
}
} catch (SaslException se) {
try {
disposeSasl();
} catch (SaslException ignored) {
}
throw se;
}
ostart = 0;
if (obuffer == null)
ofinish = 0;
else
ofinish = obuffer.length;
return ofinish;
}
/**
* Disposes of any system resources or security-sensitive information Sasl
* might be using.
*
* @exception SaslException
* if a SASL error occurs.
*/
private void disposeSasl() throws SaslException {
if (saslClient != null) {
saslClient.dispose();
}
if (saslServer != null) {
saslServer.dispose();
}
}
/**
* Constructs a SASLInputStream from an InputStream and a SaslServer <br>
* Note: if the specified InputStream or SaslServer is null, a
* NullPointerException may be thrown later when they are used.
*
* @param inStream
* the InputStream to be processed
* @param saslServer
* an initialized SaslServer object
*/
public SaslInputStream(InputStream inStream, SaslServer saslServer) {
this.inStream = new DataInputStream(inStream);
this.saslServer = saslServer;
this.saslClient = null;
}
/**
* Constructs a SASLInputStream from an InputStream and a SaslClient <br>
* Note: if the specified InputStream or SaslClient is null, a
* NullPointerException may be thrown later when they are used.
*
* @param inStream
* the InputStream to be processed
* @param saslClient
* an initialized SaslClient object
*/
public SaslInputStream(InputStream inStream, SaslClient saslClient) {
this.inStream = new DataInputStream(inStream);
this.saslServer = null;
this.saslClient = saslClient;
}
/**
* Reads the next byte of data from this input stream. The value byte is
* returned as an <code>int</code> in the range <code>0</code> to
* <code>255</code>. If no byte is available because the end of the stream has
* been reached, the value <code>-1</code> is returned. This method blocks
* until input data is available, the end of the stream is detected, or an
* exception is thrown.
* <p>
*
* @return the next byte of data, or <code>-1</code> if the end of the stream
* is reached.
* @exception IOException
* if an I/O error occurs.
*/
public int read() throws IOException {
if (ostart >= ofinish) {
// we loop for new data as we are blocking
int i = 0;
while (i == 0)
i = readMoreData();
if (i == -1)
return -1;
}
return ((int) obuffer[ostart++] & 0xff);
}
/**
* Reads up to <code>b.length</code> bytes of data from this input stream into
* an array of bytes.
* <p>
* The <code>read</code> method of <code>InputStream</code> calls the
* <code>read</code> method of three arguments with the arguments
* <code>b</code>, <code>0</code>, and <code>b.length</code>.
*
* @param b
* the buffer into which the data is read.
* @return the total number of bytes read into the buffer, or <code>-1</code>
* is there is no more data because the end of the stream has been
* reached.
* @exception IOException
* if an I/O error occurs.
*/
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
/**
* Reads up to <code>len</code> bytes of data from this input stream into an
* array of bytes. This method blocks until some input is available. If the
* first argument is <code>null,</code> up to <code>len</code> bytes are read
* and discarded.
*
* @param b
* the buffer into which the data is read.
* @param off
* the start offset of the data.
* @param len
* the maximum number of bytes read.
* @return the total number of bytes read into the buffer, or <code>-1</code>
* if there is no more data because the end of the stream has been
* reached.
* @exception IOException
* if an I/O error occurs.
*/
public int read(byte[] b, int off, int len) throws IOException {
if (ostart >= ofinish) {
// we loop for new data as we are blocking
int i = 0;
while (i == 0)
i = readMoreData();
if (i == -1)
return -1;
}
if (len <= 0) {
return 0;
}
int available = ofinish - ostart;
if (len < available)
available = len;
if (b != null) {
System.arraycopy(obuffer, ostart, b, off, available);
}
ostart = ostart + available;
return available;
}
/**
* Skips <code>n</code> bytes of input from the bytes that can be read from
* this input stream without blocking.
*
* <p>
* Fewer bytes than requested might be skipped. The actual number of bytes
* skipped is equal to <code>n</code> or the result of a call to
* {@link #available() <code>available</code>}, whichever is smaller. If
* <code>n</code> is less than zero, no bytes are skipped.
*
* <p>
* The actual number of bytes skipped is returned.
*
* @param n
* the number of bytes to be skipped.
* @return the actual number of bytes skipped.
* @exception IOException
* if an I/O error occurs.
*/
public long skip(long n) throws IOException {
int available = ofinish - ostart;
if (n > available) {
n = available;
}
if (n < 0) {
return 0;
}
ostart += n;
return n;
}
/**
* Returns the number of bytes that can be read from this input stream without
* blocking. The <code>available</code> method of <code>InputStream</code>
* returns <code>0</code>. This method <B>should</B> be overridden by
* subclasses.
*
* @return the number of bytes that can be read from this input stream without
* blocking.
* @exception IOException
* if an I/O error occurs.
*/
public int available() throws IOException {
return (ofinish - ostart);
}
/**
* Closes this input stream and releases any system resources associated with
* the stream.
* <p>
* The <code>close</code> method of <code>SASLInputStream</code> calls the
* <code>close</code> method of its underlying input stream.
*
* @exception IOException
* if an I/O error occurs.
*/
public void close() throws IOException {
disposeSasl();
ostart = 0;
ofinish = 0;
inStream.close();
}
/**
* Tests if this input stream supports the <code>mark</code> and
* <code>reset</code> methods, which it does not.
*
* @return <code>false</code>, since this class does not support the
* <code>mark</code> and <code>reset</code> methods.
*/
public boolean markSupported() {
return false;
}
}
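A short usage sketch: once the SASL negotiation has completed, the raw socket stream is wrapped so every read pulls one length-prefixed token and unwraps it through the negotiated SaslServer (or SaslClient). The socket and saslServer variables are illustrative, not part of this file.

    // Server side, after saslServer.isComplete() is true:
    InputStream secureIn = new SaslInputStream(socket.getInputStream(), saslServer);
    DataInputStream in = new DataInputStream(new BufferedInputStream(secureIn));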

org/apache/hadoop/security/SaslOutputStream.java

@ -0,0 +1,181 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
/**
* A SaslOutputStream is composed of an OutputStream and a SaslServer (or
* SaslClient) so that write() methods first process the data before writing
* them out to the underlying OutputStream. The SaslServer (or SaslClient)
* object must be fully initialized before being used by a SaslOutputStream.
*/
public class SaslOutputStream extends OutputStream {
private final DataOutputStream outStream;
// processed data ready to be written out
private byte[] saslToken;
private final SaslClient saslClient;
private final SaslServer saslServer;
// buffer holding one byte of incoming data
private final byte[] ibuffer = new byte[1];
/**
* Constructs a SASLOutputStream from an OutputStream and a SaslServer <br>
* Note: if the specified OutputStream or SaslServer is null, a
* NullPointerException may be thrown later when they are used.
*
* @param outStream
* the OutputStream to be processed
* @param saslServer
* an initialized SaslServer object
*/
public SaslOutputStream(OutputStream outStream, SaslServer saslServer) {
this.outStream = new DataOutputStream(outStream);
this.saslServer = saslServer;
this.saslClient = null;
}
/**
* Constructs a SASLOutputStream from an OutputStream and a SaslClient <br>
* Note: if the specified OutputStream or SaslClient is null, a
* NullPointerException may be thrown later when they are used.
*
* @param outStream
* the OutputStream to be processed
* @param saslClient
* an initialized SaslClient object
*/
public SaslOutputStream(OutputStream outStream, SaslClient saslClient) {
this.outStream = new DataOutputStream(outStream);
this.saslServer = null;
this.saslClient = saslClient;
}
/**
* Disposes of any system resources or security-sensitive information Sasl
* might be using.
*
* @exception SaslException
* if a SASL error occurs.
*/
private void disposeSasl() throws SaslException {
if (saslClient != null) {
saslClient.dispose();
}
if (saslServer != null) {
saslServer.dispose();
}
}
/**
* Writes the specified byte to this output stream.
*
* @param b
* the <code>byte</code>.
* @exception IOException
* if an I/O error occurs.
*/
public void write(int b) throws IOException {
ibuffer[0] = (byte) b;
write(ibuffer, 0, 1);
}
/**
* Writes <code>b.length</code> bytes from the specified byte array to this
* output stream.
* <p>
* The <code>write</code> method of <code>SASLOutputStream</code> calls the
* <code>write</code> method of three arguments with the three arguments
* <code>b</code>, <code>0</code>, and <code>b.length</code>.
*
* @param b
* the data.
* @exception NullPointerException
* if <code>b</code> is null.
* @exception IOException
* if an I/O error occurs.
*/
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
/**
* Writes <code>len</code> bytes from the specified byte array starting at
* offset <code>off</code> to this output stream.
*
* @param inBuf
* the data.
* @param off
* the start offset in the data.
* @param len
* the number of bytes to write.
* @exception IOException
* if an I/O error occurs.
*/
public void write(byte[] inBuf, int off, int len) throws IOException {
try {
if (saslServer != null) { // using saslServer
saslToken = saslServer.wrap(inBuf, off, len);
} else { // using saslClient
saslToken = saslClient.wrap(inBuf, off, len);
}
} catch (SaslException se) {
try {
disposeSasl();
} catch (SaslException ignored) {
}
throw se;
}
if (saslToken != null) {
outStream.writeInt(saslToken.length);
outStream.write(saslToken, 0, saslToken.length);
saslToken = null;
}
}
/**
* Flushes this output stream
*
* @exception IOException
* if an I/O error occurs.
*/
public void flush() throws IOException {
outStream.flush();
}
/**
* Closes this output stream and releases any system resources associated with
* this stream.
*
* @exception IOException
* if an I/O error occurs.
*/
public void close() throws IOException {
disposeSasl();
outStream.close();
}
}
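The write path mirrors the read path: each write() is wrapped by the negotiated SaslClient or SaslServer and emitted as a 4-byte length followed by the wrapped bytes, which is exactly the framing SaslInputStream.readMoreData() consumes. A client-side sketch, with socket, a completed saslClient, and requestBytes assumed:

    OutputStream secureOut = new SaslOutputStream(socket.getOutputStream(), saslClient);
    secureOut.write(requestBytes);   // wrapped by the SaslClient and sent as length + wrapped bytes
    secureOut.flush();               // flush() simply flushes the underlying stream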

org/apache/hadoop/security/SaslRpcClient.java

@ -0,0 +1,249 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslClient;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
/**
 * A utility class that encapsulates SASL logic for the RPC client
*/
public class SaslRpcClient {
public static final Log LOG = LogFactory.getLog(SaslRpcClient.class);
private final SaslClient saslClient;
/**
* Create a SaslRpcClient for an authentication method
*
* @param method
* the requested authentication method
* @param token
* token to use if needed by the authentication method
*/
public SaslRpcClient(AuthMethod method,
Token<? extends TokenIdentifier> token, String serverPrincipal)
throws IOException {
switch (method) {
case DIGEST:
if (LOG.isDebugEnabled())
LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
+ " client to authenticate to service at " + token.getService());
saslClient = Sasl.createSaslClient(new String[] { AuthMethod.DIGEST
.getMechanismName() }, null, null, SaslRpcServer.SASL_DEFAULT_REALM,
SaslRpcServer.SASL_PROPS, new SaslClientCallbackHandler(token));
break;
case KERBEROS:
if (LOG.isDebugEnabled()) {
        LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
            + " client. Server's Kerberos principal name is " + serverPrincipal);
}
if (serverPrincipal == null || serverPrincipal.length() == 0) {
throw new IOException(
"Failed to specify server's Kerberos principal name");
}
String names[] = SaslRpcServer.splitKerberosName(serverPrincipal);
if (names.length != 3) {
throw new IOException(
"Kerberos principal name does NOT have the expected hostname part: "
+ serverPrincipal);
}
saslClient = Sasl.createSaslClient(new String[] { AuthMethod.KERBEROS
.getMechanismName() }, null, names[0], names[1],
SaslRpcServer.SASL_PROPS, null);
break;
default:
throw new IOException("Unknown authentication method " + method);
}
if (saslClient == null)
throw new IOException("Unable to find SASL client implementation");
}
/**
   * Do client-side SASL authentication with the server via the given InputStream
   * and OutputStream
*
* @param inS
* InputStream to use
* @param outS
* OutputStream to use
* @throws IOException
*/
public void saslConnect(InputStream inS, OutputStream outS)
throws IOException {
DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(
outS));
try {
byte[] saslToken = new byte[0];
if (saslClient.hasInitialResponse())
saslToken = saslClient.evaluateChallenge(saslToken);
if (saslToken != null) {
outStream.writeInt(saslToken.length);
outStream.write(saslToken, 0, saslToken.length);
outStream.flush();
if (LOG.isDebugEnabled())
LOG.debug("Have sent token of size " + saslToken.length
+ " from initSASLContext.");
}
if (!saslClient.isComplete()) {
saslToken = new byte[inStream.readInt()];
if (LOG.isDebugEnabled())
LOG.debug("Will read input token of size " + saslToken.length
+ " for processing by initSASLContext");
inStream.readFully(saslToken);
}
while (!saslClient.isComplete()) {
saslToken = saslClient.evaluateChallenge(saslToken);
if (saslToken != null) {
if (LOG.isDebugEnabled())
LOG.debug("Will send token of size " + saslToken.length
+ " from initSASLContext.");
outStream.writeInt(saslToken.length);
outStream.write(saslToken, 0, saslToken.length);
outStream.flush();
}
if (!saslClient.isComplete()) {
saslToken = new byte[inStream.readInt()];
if (LOG.isDebugEnabled())
LOG.debug("Will read input token of size " + saslToken.length
+ " for processing by initSASLContext");
inStream.readFully(saslToken);
}
}
if (LOG.isDebugEnabled()) {
LOG.debug("SASL client context established. Negotiated QoP: "
+ saslClient.getNegotiatedProperty(Sasl.QOP));
}
} catch (IOException e) {
saslClient.dispose();
throw e;
}
}
/**
* Get a SASL wrapped InputStream. Can be called only after saslConnect() has
* been called.
*
* @param in
* the InputStream to wrap
* @return a SASL wrapped InputStream
* @throws IOException
*/
public InputStream getInputStream(InputStream in) throws IOException {
if (!saslClient.isComplete()) {
throw new IOException("Sasl authentication exchange hasn't completed yet");
}
return new SaslInputStream(in, saslClient);
}
/**
* Get a SASL wrapped OutputStream. Can be called only after saslConnect() has
* been called.
*
* @param out
* the OutputStream to wrap
* @return a SASL wrapped OutputStream
* @throws IOException
*/
public OutputStream getOutputStream(OutputStream out) throws IOException {
if (!saslClient.isComplete()) {
throw new IOException("Sasl authentication exchange hasn't completed yet");
}
return new SaslOutputStream(out, saslClient);
}
/** Release resources used by wrapped saslClient */
public void dispose() throws SaslException {
saslClient.dispose();
}
private static class SaslClientCallbackHandler implements CallbackHandler {
private final String userName;
private final char[] userPassword;
public SaslClientCallbackHandler(Token<? extends TokenIdentifier> token) {
this.userName = SaslRpcServer.encodeIdentifier(token.getIdentifier());
this.userPassword = SaslRpcServer.encodePassword(token.getPassword());
}
public void handle(Callback[] callbacks)
throws UnsupportedCallbackException {
NameCallback nc = null;
PasswordCallback pc = null;
RealmCallback rc = null;
for (Callback callback : callbacks) {
if (callback instanceof RealmChoiceCallback) {
continue;
} else if (callback instanceof NameCallback) {
nc = (NameCallback) callback;
} else if (callback instanceof PasswordCallback) {
pc = (PasswordCallback) callback;
} else if (callback instanceof RealmCallback) {
rc = (RealmCallback) callback;
} else {
throw new UnsupportedCallbackException(callback,
"Unrecognized SASL client callback");
}
}
if (nc != null) {
if (LOG.isDebugEnabled())
LOG.debug("SASL client callback: setting username: " + userName);
nc.setName(userName);
}
if (pc != null) {
if (LOG.isDebugEnabled())
LOG.debug("SASL client callback: setting userPassword");
pc.setPassword(userPassword);
}
if (rc != null) {
if (LOG.isDebugEnabled())
LOG.debug("SASL client callback: setting realm: "
+ rc.getDefaultText());
rc.setText(rc.getDefaultText());
}
}
}
}
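A rough usage sketch for the client class above, using the DIGEST case. SaslRpcClientUsage, sock and token are illustrative names only, and the socket is assumed to be already connected to a server that speaks the same length-prefixed SASL framing. For the KERBEROS case, the third constructor argument must instead be a full name/host@REALM principal, since the constructor splits it on '/' and '@' and passes the first two parts to the SASL library.

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;

import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;

public class SaslRpcClientUsage {
  // Authenticates over an already-connected socket, then wraps the raw
  // streams so that subsequent traffic is protected per the negotiated QOP.
  static SaslRpcClient digestConnect(Socket sock,
      Token<? extends TokenIdentifier> token) throws IOException {
    // DIGEST does not need a server principal, so the last argument is null.
    SaslRpcClient client = new SaslRpcClient(AuthMethod.DIGEST, token, null);
    client.saslConnect(sock.getInputStream(), sock.getOutputStream());
    // Only valid after saslConnect() has completed the exchange.
    InputStream in = client.getInputStream(sock.getInputStream());
    OutputStream out = client.getOutputStream(sock.getOutputStream());
    // ... RPC requests and responses would now go over in/out ...
    return client;
  }
}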

View File

@ -0,0 +1,218 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import java.util.TreeMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
/**
 * A utility class for dealing with SASL on the RPC server side
*/
public class SaslRpcServer {
public static final Log LOG = LogFactory.getLog(SaslRpcServer.class);
public static final String SASL_DEFAULT_REALM = "default";
public static final Map<String, String> SASL_PROPS =
new TreeMap<String, String>();
static {
// Request authentication plus integrity protection
SASL_PROPS.put(Sasl.QOP, "auth-int");
// Request mutual authentication
SASL_PROPS.put(Sasl.SERVER_AUTH, "true");
}
static String encodeIdentifier(byte[] identifier) {
return new String(Base64.encodeBase64(identifier));
}
static byte[] decodeIdentifier(String identifier) {
return Base64.decodeBase64(identifier.getBytes());
}
static char[] encodePassword(byte[] password) {
return new String(Base64.encodeBase64(password)).toCharArray();
}
  /** Splits a fully qualified Kerberos name into its parts */
public static String[] splitKerberosName(String fullName) {
return fullName.split("[/@]");
}
/** Authentication method */
public static enum AuthMethod {
SIMPLE((byte) 80, ""), // no authentication
KERBEROS((byte) 81, "GSSAPI"), // SASL Kerberos authentication
DIGEST((byte) 82, "DIGEST-MD5"); // SASL DIGEST-MD5 authentication
/** The code for this method. */
public final byte code;
public final String mechanismName;
private AuthMethod(byte code, String mechanismName) {
this.code = code;
this.mechanismName = mechanismName;
}
private static final int FIRST_CODE = values()[0].code;
/** Return the object represented by the code. */
private static AuthMethod valueOf(byte code) {
final int i = (code & 0xff) - FIRST_CODE;
return i < 0 || i >= values().length ? null : values()[i];
}
/** Return the SASL mechanism name */
public String getMechanismName() {
return mechanismName;
}
/** Read from in */
public static AuthMethod read(DataInput in) throws IOException {
return valueOf(in.readByte());
}
/** Write to out */
public void write(DataOutput out) throws IOException {
out.write(code);
}
};
/** CallbackHandler for SASL DIGEST-MD5 mechanism */
public static class SaslDigestCallbackHandler implements CallbackHandler {
private SecretManager<TokenIdentifier> secretManager;
public SaslDigestCallbackHandler(
SecretManager<TokenIdentifier> secretManager) {
this.secretManager = secretManager;
}
private TokenIdentifier getIdentifier(String id) throws IOException {
byte[] tokenId = decodeIdentifier(id);
TokenIdentifier tokenIdentifier = secretManager.createIdentifier();
tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(
tokenId)));
return tokenIdentifier;
}
private char[] getPassword(TokenIdentifier tokenid) throws IOException {
return encodePassword(secretManager.retrievePassword(tokenid));
}
/** {@inheritDoc} */
@Override
public void handle(Callback[] callbacks) throws IOException,
UnsupportedCallbackException {
NameCallback nc = null;
PasswordCallback pc = null;
AuthorizeCallback ac = null;
for (Callback callback : callbacks) {
if (callback instanceof AuthorizeCallback) {
ac = (AuthorizeCallback) callback;
} else if (callback instanceof NameCallback) {
nc = (NameCallback) callback;
} else if (callback instanceof PasswordCallback) {
pc = (PasswordCallback) callback;
} else if (callback instanceof RealmCallback) {
continue; // realm is ignored
} else {
throw new UnsupportedCallbackException(callback,
"Unrecognized SASL DIGEST-MD5 Callback");
}
}
if (pc != null) {
TokenIdentifier tokenIdentifier = getIdentifier(nc.getDefaultName());
char[] password = getPassword(tokenIdentifier);
if (LOG.isDebugEnabled()) {
LOG.debug("SASL server DIGEST-MD5 callback: setting password "
+ "for client: " + tokenIdentifier.getUsername());
}
pc.setPassword(password);
}
if (ac != null) {
String authid = ac.getAuthenticationID();
String authzid = ac.getAuthorizationID();
if (authid.equals(authzid)) {
ac.setAuthorized(true);
} else {
ac.setAuthorized(false);
}
if (ac.isAuthorized()) {
String username = getIdentifier(authzid).getUsername().toString();
if (LOG.isDebugEnabled())
LOG.debug("SASL server DIGEST-MD5 callback: setting "
+ "canonicalized client ID: " + username);
ac.setAuthorizedID(username);
}
}
}
}
/** CallbackHandler for SASL GSSAPI Kerberos mechanism */
public static class SaslGssCallbackHandler implements CallbackHandler {
/** {@inheritDoc} */
@Override
public void handle(Callback[] callbacks) throws IOException,
UnsupportedCallbackException {
AuthorizeCallback ac = null;
for (Callback callback : callbacks) {
if (callback instanceof AuthorizeCallback) {
ac = (AuthorizeCallback) callback;
} else {
throw new UnsupportedCallbackException(callback,
"Unrecognized SASL GSSAPI Callback");
}
}
if (ac != null) {
String authid = ac.getAuthenticationID();
String authzid = ac.getAuthorizationID();
if (authid.equals(authzid)) {
ac.setAuthorized(true);
} else {
ac.setAuthorized(false);
}
if (ac.isAuthorized()) {
if (LOG.isDebugEnabled())
LOG.debug("SASL server GSSAPI callback: setting "
+ "canonicalized client ID: " + authzid);
ac.setAuthorizedID(authzid);
}
}
}
}
}
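To show how the server-side pieces above fit together, here is a rough sketch of the accept-side counterpart to SaslRpcClient.saslConnect(): it builds a DIGEST-MD5 SaslServer around the SaslDigestCallbackHandler and loops over length-prefixed tokens until the exchange completes. SaslServerNegotiationSketch and its negotiate() method are illustrative names; the real loop lives in the RPC Server's connection handling, which is not part of this excerpt.

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import javax.security.sasl.Sasl;
import javax.security.sasl.SaslServer;

import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SaslRpcServer.SaslDigestCallbackHandler;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;

public class SaslServerNegotiationSketch {
  // Runs the server half of the SASL exchange over already-framed streams.
  static SaslServer negotiate(DataInputStream in, DataOutputStream out,
      SecretManager<TokenIdentifier> secretManager) throws IOException {
    SaslServer saslServer = Sasl.createSaslServer(
        AuthMethod.DIGEST.getMechanismName(), null,
        SaslRpcServer.SASL_DEFAULT_REALM, SaslRpcServer.SASL_PROPS,
        new SaslDigestCallbackHandler(secretManager));
    if (saslServer == null) {
      throw new IOException("Unable to find SASL server implementation");
    }
    while (!saslServer.isComplete()) {
      byte[] token = new byte[in.readInt()];   // client token, length-prefixed
      in.readFully(token);
      byte[] reply = saslServer.evaluateResponse(token);
      if (reply != null) {                     // send challenge back, same framing
        out.writeInt(reply.length);
        out.write(reply, 0, reply.length);
        out.flush();
      }
    }
    return saslServer;   // getAuthorizationID() now names the authorized client
  }
}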

View File

@ -61,6 +61,12 @@ public abstract class SecretManager<T extends TokenIdentifier> {
*/
public abstract byte[] retrievePassword(T identifier) throws InvalidToken;
/**
* Create an empty token identifier.
* @return the newly created empty token identifier
*/
public abstract T createIdentifier();
/**
* The name of the hashing algorithm.
*/

View File

@ -35,6 +35,12 @@ public abstract class TokenIdentifier implements Writable {
* @return the kind of the token
*/
public abstract Text getKind();
/**
* Get the username encoded in the token identifier
* @return the username
*/
public abstract Text getUsername();
/**
* Get the bytes for the token identifier

View File

@ -0,0 +1,31 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security.token;
import java.lang.annotation.*;
/**
 * Indicates token-related information to be used for the annotated protocol
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface TokenInfo {
/** The type of TokenSelector to be used */
Class<? extends TokenSelector<? extends TokenIdentifier>> value();
}
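As a small illustration of how this annotation is meant to be attached, the sketch below tags a hypothetical protocol interface; MyTokenProtocol is not part of the patch, and TestSaslRPC further down does the same thing with its TestTokenSelector.

import java.io.IOException;

import org.apache.hadoop.ipc.TestSaslRPC;
import org.apache.hadoop.security.token.TokenInfo;

// Hypothetical RPC protocol interface. The declared TokenSelector tells the
// security layer which token to pick from the caller's credentials when
// connecting to a server for this protocol.
@TokenInfo(TestSaslRPC.TestTokenSelector.class)
public interface MyTokenProtocol {
  void ping() throws IOException;
}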

View File

@ -0,0 +1,34 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security.token;
import java.util.Collection;
import org.apache.hadoop.io.Text;
/**
 * Selects a token of type T from a collection of tokens for use with the named service
*
* @param <T>
* T extends TokenIdentifier
*/
public interface TokenSelector<T extends TokenIdentifier> {
Token<T> selectToken(Text service,
Collection<Token<? extends TokenIdentifier>> tokens);
}

View File

@ -63,4 +63,10 @@
<description>The name of the s3n file system for testing.</description>
</property>
<!-- Turn security off for tests by default -->
<property>
<name>hadoop.security.authentication</name>
<value>simple</value>
</property>
</configuration>

View File

@ -68,7 +68,7 @@ public class TestRPC extends TestCase {
int[] exchange(int[] values) throws IOException;
}
public class TestImpl implements TestProtocol { public static class TestImpl implements TestProtocol {
int fastPingCounter = 0;
public long getProtocolVersion(String protocol, long clientVersion) {
@ -189,7 +189,7 @@ public class TestRPC extends TestCase {
System.out.println("Testing Slow RPC");
// create a server with two handlers
Server server = RPC.getServer(TestProtocol.class,
new TestImpl(), ADDRESS, 0, 2, false, conf); new TestImpl(), ADDRESS, 0, 2, false, conf, null);
TestProtocol proxy = null;
try {
@ -339,7 +339,7 @@ public class TestRPC extends TestCase {
ServiceAuthorizationManager.refresh(conf, new TestPolicyProvider());
Server server = RPC.getServer(TestProtocol.class,
new TestImpl(), ADDRESS, 0, 5, true, conf); new TestImpl(), ADDRESS, 0, 5, true, conf, null);
TestProtocol proxy = null;

View File

@ -0,0 +1,216 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import static org.apache.hadoop.fs.CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Collection;
import org.apache.commons.logging.*;
import org.apache.commons.logging.impl.Log4JLogger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.security.SaslInputStream;
import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.log4j.Level;
import org.junit.Test;
/** Unit tests for using SASL over RPC. */
public class TestSaslRPC {
private static final String ADDRESS = "0.0.0.0";
public static final Log LOG =
LogFactory.getLog(TestSaslRPC.class);
static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
private static Configuration conf;
static {
conf = new Configuration();
conf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos");
UserGroupInformation.setConfiguration(conf);
}
static {
((Log4JLogger) Client.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) Server.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) SaslRpcClient.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) SaslRpcServer.LOG).getLogger().setLevel(Level.ALL);
((Log4JLogger) SaslInputStream.LOG).getLogger().setLevel(Level.ALL);
}
public static class TestTokenIdentifier extends TokenIdentifier {
private Text tokenid;
final static Text KIND_NAME = new Text("test.token");
public TestTokenIdentifier() {
this.tokenid = new Text();
}
public TestTokenIdentifier(Text tokenid) {
this.tokenid = tokenid;
}
@Override
public Text getKind() {
return KIND_NAME;
}
@Override
public Text getUsername() {
return tokenid;
}
@Override
public void readFields(DataInput in) throws IOException {
tokenid.readFields(in);
}
@Override
public void write(DataOutput out) throws IOException {
tokenid.write(out);
}
}
public static class TestTokenSecretManager extends
SecretManager<TestTokenIdentifier> {
public byte[] createPassword(TestTokenIdentifier id) {
return id.getBytes();
}
public byte[] retrievePassword(TestTokenIdentifier id)
throws InvalidToken {
return id.getBytes();
}
public TestTokenIdentifier createIdentifier() {
return new TestTokenIdentifier();
}
}
public static class TestTokenSelector implements
TokenSelector<TestTokenIdentifier> {
@SuppressWarnings("unchecked")
@Override
public Token<TestTokenIdentifier> selectToken(Text service,
Collection<Token<? extends TokenIdentifier>> tokens) {
if (service == null) {
return null;
}
for (Token<? extends TokenIdentifier> token : tokens) {
if (TestTokenIdentifier.KIND_NAME.equals(token.getKind())
&& service.equals(token.getService())) {
return (Token<TestTokenIdentifier>) token;
}
}
return null;
}
}
@KerberosInfo(SERVER_PRINCIPAL_KEY)
@TokenInfo(TestTokenSelector.class)
public interface TestSaslProtocol extends TestRPC.TestProtocol {
}
public static class TestSaslImpl extends TestRPC.TestImpl implements
TestSaslProtocol {
}
@Test
public void testDigestRpc() throws Exception {
TestTokenSecretManager sm = new TestTokenSecretManager();
final Server server = RPC.getServer(TestSaslProtocol.class,
new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
server.start();
final UserGroupInformation current = UserGroupInformation.getCurrentUser();
final InetSocketAddress addr = NetUtils.getConnectAddress(server);
TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
.getUserName()));
Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
sm);
Text host = new Text(addr.getAddress().getHostAddress() + ":"
+ addr.getPort());
token.setService(host);
LOG.info("Service IP address for token is " + host);
current.addToken(token);
TestSaslProtocol proxy = null;
try {
proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, conf);
proxy.ping();
} finally {
server.stop();
if (proxy != null) {
RPC.stopProxy(proxy);
}
}
}
static void testKerberosRpc(String principal, String keytab) throws Exception {
final Configuration newConf = new Configuration(conf);
newConf.set(SERVER_PRINCIPAL_KEY, principal);
UserGroupInformation.loginUserFromKeytab(principal, keytab);
UserGroupInformation current = UserGroupInformation.getCurrentUser();
System.out.println("UGI: " + current);
Server server = RPC.getServer(TestSaslProtocol.class, new TestSaslImpl(),
ADDRESS, 0, 5, true, newConf, null);
TestSaslProtocol proxy = null;
server.start();
InetSocketAddress addr = NetUtils.getConnectAddress(server);
try {
proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
TestSaslProtocol.versionID, addr, newConf);
proxy.ping();
} finally {
server.stop();
if (proxy != null) {
RPC.stopProxy(proxy);
}
}
}
public static void main(String[] args) throws Exception {
System.out.println("Testing Kerberos authentication over RPC");
if (args.length != 2) {
System.err
.println("Usage: java <options> org.apache.hadoop.ipc.TestSaslRPC "
+ " <serverPrincipal> <keytabFile>");
System.exit(-1);
}
String principal = args[0];
String keytab = args[1];
testKerberosRpc(principal, keytab);
}
}