HADOOP-12819. Migrate TestSaslRPC and related codes to rebase on ProtobufRpcEngine. Contributed by Kai Zheng.

This commit is contained in:
Haohui Mai 2016-03-20 17:40:59 -07:00
parent 6236782151
commit 478a25b929
6 changed files with 305 additions and 303 deletions

View File

@ -29,6 +29,22 @@ import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.SecretManager;
import org.junit.Assert; import org.junit.Assert;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.protobuf.TestProtos;
import org.apache.hadoop.ipc.protobuf.TestRpcServiceProtos;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.junit.Assert;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException; import java.io.IOException;
import java.lang.management.ManagementFactory; import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo; import java.lang.management.ThreadInfo;
@ -37,6 +53,8 @@ import java.net.InetSocketAddress;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -149,6 +167,89 @@ public class TestRpcBase {
return count; return count;
} }
public static class TestTokenIdentifier extends TokenIdentifier {
private Text tokenid;
private Text realUser;
final static Text KIND_NAME = new Text("test.token");
public TestTokenIdentifier() {
this(new Text(), new Text());
}
public TestTokenIdentifier(Text tokenid) {
this(tokenid, new Text());
}
public TestTokenIdentifier(Text tokenid, Text realUser) {
this.tokenid = tokenid == null ? new Text() : tokenid;
this.realUser = realUser == null ? new Text() : realUser;
}
@Override
public Text getKind() {
return KIND_NAME;
}
@Override
public UserGroupInformation getUser() {
if (realUser.toString().isEmpty()) {
return UserGroupInformation.createRemoteUser(tokenid.toString());
} else {
UserGroupInformation realUgi = UserGroupInformation
.createRemoteUser(realUser.toString());
return UserGroupInformation
.createProxyUser(tokenid.toString(), realUgi);
}
}
@Override
public void readFields(DataInput in) throws IOException {
tokenid.readFields(in);
realUser.readFields(in);
}
@Override
public void write(DataOutput out) throws IOException {
tokenid.write(out);
realUser.write(out);
}
}
public static class TestTokenSecretManager extends
SecretManager<TestTokenIdentifier> {
@Override
public byte[] createPassword(TestTokenIdentifier id) {
return id.getBytes();
}
@Override
public byte[] retrievePassword(TestTokenIdentifier id)
throws InvalidToken {
return id.getBytes();
}
@Override
public TestTokenIdentifier createIdentifier() {
return new TestTokenIdentifier();
}
}
public static class TestTokenSelector implements
TokenSelector<TestTokenIdentifier> {
@SuppressWarnings("unchecked")
@Override
public Token<TestTokenIdentifier> selectToken(Text service,
Collection<Token<? extends TokenIdentifier>> tokens) {
if (service == null) {
return null;
}
for (Token<? extends TokenIdentifier> token : tokens) {
if (TestTokenIdentifier.KIND_NAME.equals(token.getKind())
&& service.equals(token.getService())) {
return (Token<TestTokenIdentifier>) token;
}
}
return null;
}
}
@KerberosInfo(serverPrincipal = SERVER_PRINCIPAL_KEY)
@TokenInfo(TestTokenSelector.class)
@ProtocolInfo(protocolName = "org.apache.hadoop.ipc.TestRpcBase$TestRpcService", @ProtocolInfo(protocolName = "org.apache.hadoop.ipc.TestRpcBase$TestRpcService",
protocolVersion = 1) protocolVersion = 1)
public interface TestRpcService public interface TestRpcService
@ -267,12 +368,80 @@ public class TestRpcBase {
} catch (InterruptedException ignore) {} } catch (InterruptedException ignore) {}
return TestProtos.EmptyResponseProto.newBuilder().build(); return TestProtos.EmptyResponseProto.newBuilder().build();
} }
@Override
public TestProtos.AuthMethodResponseProto getAuthMethod(
RpcController controller, TestProtos.EmptyRequestProto request)
throws ServiceException {
AuthMethod authMethod = null;
try {
authMethod = UserGroupInformation.getCurrentUser()
.getAuthenticationMethod().getAuthMethod();
} catch (IOException e) {
throw new ServiceException(e);
}
return TestProtos.AuthMethodResponseProto.newBuilder()
.setCode(authMethod.code)
.setMechanismName(authMethod.getMechanismName())
.build();
}
@Override
public TestProtos.AuthUserResponseProto getAuthUser(
RpcController controller, TestProtos.EmptyRequestProto request)
throws ServiceException {
UserGroupInformation authUser = null;
try {
authUser = UserGroupInformation.getCurrentUser();
} catch (IOException e) {
throw new ServiceException(e);
}
return TestProtos.AuthUserResponseProto.newBuilder()
.setAuthUser(authUser.getUserName())
.build();
}
@Override
public TestProtos.EchoResponseProto echoPostponed(
RpcController controller, TestProtos.EchoRequestProto request)
throws ServiceException {
Server.Call call = Server.getCurCall().get();
call.postponeResponse();
postponedCalls.add(call);
return TestProtos.EchoResponseProto.newBuilder().setMessage(
request.getMessage())
.build();
}
@Override
public TestProtos.EmptyResponseProto sendPostponed(
RpcController controller, TestProtos.EmptyRequestProto request)
throws ServiceException {
Collections.shuffle(postponedCalls);
try {
for (Server.Call call : postponedCalls) {
call.sendResponse();
}
} catch (IOException e) {
throw new ServiceException(e);
}
postponedCalls.clear();
return TestProtos.EmptyResponseProto.newBuilder().build();
}
} }
protected static TestProtos.EmptyRequestProto newEmptyRequest() { protected static TestProtos.EmptyRequestProto newEmptyRequest() {
return TestProtos.EmptyRequestProto.newBuilder().build(); return TestProtos.EmptyRequestProto.newBuilder().build();
} }
protected static TestProtos.EmptyResponseProto newEmptyResponse() {
return TestProtos.EmptyResponseProto.newBuilder().build();
}
protected static TestProtos.EchoRequestProto newEchoRequest(String msg) { protected static TestProtos.EchoRequestProto newEchoRequest(String msg) {
return TestProtos.EchoRequestProto.newBuilder().setMessage(msg).build(); return TestProtos.EchoRequestProto.newBuilder().setMessage(msg).build();
} }
@ -292,4 +461,25 @@ public class TestRpcBase {
return TestProtos.SleepRequestProto.newBuilder() return TestProtos.SleepRequestProto.newBuilder()
.setMilliSeconds(milliSeconds).build(); .setMilliSeconds(milliSeconds).build();
} }
protected static TestProtos.EchoResponseProto newEchoResponse(String msg) {
return TestProtos.EchoResponseProto.newBuilder().setMessage(msg).build();
}
protected static AuthMethod convert(
TestProtos.AuthMethodResponseProto authMethodResponse) {
String mechanism = authMethodResponse.getMechanismName();
if (mechanism.equals(AuthMethod.SIMPLE.getMechanismName())) {
return AuthMethod.SIMPLE;
} else if (mechanism.equals(AuthMethod.KERBEROS.getMechanismName())) {
return AuthMethod.KERBEROS;
} else if (mechanism.equals(AuthMethod.TOKEN.getMechanismName())) {
return AuthMethod.TOKEN;
}
return null;
}
protected static String convert(TestProtos.AuthUserResponseProto response) {
return response.getAuthUser();
}
} }

View File

@ -18,53 +18,7 @@
package org.apache.hadoop.ipc; package org.apache.hadoop.ipc;
import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_RPC_PROTECTION; import com.google.protobuf.ServiceException;
import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.KERBEROS;
import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.SIMPLE;
import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.TOKEN;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.security.PrivilegedExceptionAction;
import java.security.Security;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -74,27 +28,13 @@ import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.ipc.Server.Call;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo; import org.apache.hadoop.security.*;
import org.apache.hadoop.security.SaslInputStream;
import org.apache.hadoop.security.SaslPlainServer;
import org.apache.hadoop.security.SaslPropertiesResolver;
import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod; import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SaslRpcServer.QualityOfProtection; import org.apache.hadoop.security.SaslRpcServer.QualityOfProtection;
import org.apache.hadoop.security.SecurityInfo;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.TestUserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.*;
import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.log4j.Level; import org.apache.log4j.Level;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
@ -104,9 +44,27 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters; import org.junit.runners.Parameterized.Parameters;
import javax.security.auth.callback.*;
import javax.security.sasl.*;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.security.PrivilegedExceptionAction;
import java.security.Security;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_RPC_PROTECTION;
import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.*;
import static org.junit.Assert.*;
/** Unit tests for using Sasl over RPC. */ /** Unit tests for using Sasl over RPC. */
@RunWith(Parameterized.class) @RunWith(Parameterized.class)
public class TestSaslRPC { public class TestSaslRPC extends TestRpcBase {
@Parameters @Parameters
public static Collection<Object[]> data() { public static Collection<Object[]> data() {
Collection<Object[]> params = new ArrayList<Object[]>(); Collection<Object[]> params = new ArrayList<Object[]>();
@ -136,17 +94,13 @@ public class TestSaslRPC {
this.saslPropertiesResolver = saslPropertiesResolver; this.saslPropertiesResolver = saslPropertiesResolver;
} }
private static final String ADDRESS = "0.0.0.0";
public static final Log LOG = public static final Log LOG =
LogFactory.getLog(TestSaslRPC.class); LogFactory.getLog(TestSaslRPC.class);
static final String ERROR_MESSAGE = "Token is invalid"; static final String ERROR_MESSAGE = "Token is invalid";
static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab"; static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR"; static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";
private static Configuration conf;
// If this is set to true AND the auth-method is not simple, secretManager // If this is set to true AND the auth-method is not simple, secretManager
// will be enabled. // will be enabled.
static Boolean enableSecretManager = null; static Boolean enableSecretManager = null;
@ -155,7 +109,7 @@ public class TestSaslRPC {
static Boolean forceSecretManager = null; static Boolean forceSecretManager = null;
static Boolean clientFallBackToSimpleAllowed = true; static Boolean clientFallBackToSimpleAllowed = true;
static enum UseToken { enum UseToken {
NONE(), NONE(),
VALID(), VALID(),
INVALID(), INVALID(),
@ -174,6 +128,7 @@ public class TestSaslRPC {
LOG.info("---------------------------------"); LOG.info("---------------------------------");
LOG.info("Testing QOP:"+ getQOPNames(qop)); LOG.info("Testing QOP:"+ getQOPNames(qop));
LOG.info("---------------------------------"); LOG.info("---------------------------------");
conf = new Configuration(); conf = new Configuration();
// the specific tests for kerberos will enable kerberos. forcing it // the specific tests for kerberos will enable kerberos. forcing it
// for all tests will cause tests to fail if the user has a TGT // for all tests will cause tests to fail if the user has a TGT
@ -187,6 +142,9 @@ public class TestSaslRPC {
enableSecretManager = null; enableSecretManager = null;
forceSecretManager = null; forceSecretManager = null;
clientFallBackToSimpleAllowed = true; clientFallBackToSimpleAllowed = true;
// Set RPC engine to protobuf RPC engine
RPC.setProtocolEngine(conf, TestRpcService.class, ProtobufRpcEngine.class);
} }
static String getQOPNames (QualityOfProtection[] qops){ static String getQOPNames (QualityOfProtection[] qops){
@ -210,68 +168,6 @@ public class TestSaslRPC {
((Log4JLogger) SecurityUtil.LOG).getLogger().setLevel(Level.ALL); ((Log4JLogger) SecurityUtil.LOG).getLogger().setLevel(Level.ALL);
} }
public static class TestTokenIdentifier extends TokenIdentifier {
private Text tokenid;
private Text realUser;
final static Text KIND_NAME = new Text("test.token");
public TestTokenIdentifier() {
this(new Text(), new Text());
}
public TestTokenIdentifier(Text tokenid) {
this(tokenid, new Text());
}
public TestTokenIdentifier(Text tokenid, Text realUser) {
this.tokenid = tokenid == null ? new Text() : tokenid;
this.realUser = realUser == null ? new Text() : realUser;
}
@Override
public Text getKind() {
return KIND_NAME;
}
@Override
public UserGroupInformation getUser() {
if (realUser.toString().isEmpty()) {
return UserGroupInformation.createRemoteUser(tokenid.toString());
} else {
UserGroupInformation realUgi = UserGroupInformation
.createRemoteUser(realUser.toString());
return UserGroupInformation
.createProxyUser(tokenid.toString(), realUgi);
}
}
@Override
public void readFields(DataInput in) throws IOException {
tokenid.readFields(in);
realUser.readFields(in);
}
@Override
public void write(DataOutput out) throws IOException {
tokenid.write(out);
realUser.write(out);
}
}
public static class TestTokenSecretManager extends
SecretManager<TestTokenIdentifier> {
@Override
public byte[] createPassword(TestTokenIdentifier id) {
return id.getBytes();
}
@Override
public byte[] retrievePassword(TestTokenIdentifier id)
throws InvalidToken {
return id.getBytes();
}
@Override
public TestTokenIdentifier createIdentifier() {
return new TestTokenIdentifier();
}
}
public static class BadTokenSecretManager extends TestTokenSecretManager { public static class BadTokenSecretManager extends TestTokenSecretManager {
@Override @Override
@ -281,64 +177,6 @@ public class TestSaslRPC {
} }
} }
public static class TestTokenSelector implements
TokenSelector<TestTokenIdentifier> {
@SuppressWarnings("unchecked")
@Override
public Token<TestTokenIdentifier> selectToken(Text service,
Collection<Token<? extends TokenIdentifier>> tokens) {
if (service == null) {
return null;
}
for (Token<? extends TokenIdentifier> token : tokens) {
if (TestTokenIdentifier.KIND_NAME.equals(token.getKind())
&& service.equals(token.getService())) {
return (Token<TestTokenIdentifier>) token;
}
}
return null;
}
}
@KerberosInfo(
serverPrincipal = SERVER_PRINCIPAL_KEY)
@TokenInfo(TestTokenSelector.class)
public interface TestSaslProtocol extends TestRPC.TestProtocol {
public AuthMethod getAuthMethod() throws IOException;
public String getAuthUser() throws IOException;
public String echoPostponed(String value) throws IOException;
public void sendPostponed() throws IOException;
}
public static class TestSaslImpl extends TestRPC.TestImpl implements
TestSaslProtocol {
private List<Call> postponedCalls = new ArrayList<Call>();
@Override
public AuthMethod getAuthMethod() throws IOException {
return UserGroupInformation.getCurrentUser()
.getAuthenticationMethod().getAuthMethod();
}
@Override
public String getAuthUser() throws IOException {
return UserGroupInformation.getCurrentUser().getUserName();
}
@Override
public String echoPostponed(String value) {
Call call = Server.getCurCall().get();
call.postponeResponse();
postponedCalls.add(call);
return value;
}
@Override
public void sendPostponed() throws IOException {
Collections.shuffle(postponedCalls);
for (Call call : postponedCalls) {
call.sendResponse();
}
postponedCalls.clear();
}
}
public static class CustomSecurityInfo extends SecurityInfo { public static class CustomSecurityInfo extends SecurityInfo {
@Override @Override
@ -378,10 +216,7 @@ public class TestSaslRPC {
@Test @Test
public void testDigestRpc() throws Exception { public void testDigestRpc() throws Exception {
TestTokenSecretManager sm = new TestTokenSecretManager(); TestTokenSecretManager sm = new TestTokenSecretManager();
final Server server = new RPC.Builder(conf) final Server server = setupTestServer(conf, 5, sm);
.setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
.setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
.setSecretManager(sm).build();
doDigestRpc(server, sm); doDigestRpc(server, sm);
} }
@ -391,10 +226,7 @@ public class TestSaslRPC {
TestTokenSecretManager sm = new TestTokenSecretManager(); TestTokenSecretManager sm = new TestTokenSecretManager();
try { try {
SecurityUtil.setSecurityInfoProviders(new CustomSecurityInfo()); SecurityUtil.setSecurityInfoProviders(new CustomSecurityInfo());
final Server server = new RPC.Builder(conf) final Server server = setupTestServer(conf, 5, sm);
.setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
.setBindAddress(ADDRESS).setPort(0).setNumHandlers(5)
.setVerbose(true).setSecretManager(sm).build();
doDigestRpc(server, sm); doDigestRpc(server, sm);
} finally { } finally {
SecurityUtil.setSecurityInfoProviders(new SecurityInfo[0]); SecurityUtil.setSecurityInfoProviders(new SecurityInfo[0]);
@ -404,58 +236,46 @@ public class TestSaslRPC {
@Test @Test
public void testErrorMessage() throws Exception { public void testErrorMessage() throws Exception {
BadTokenSecretManager sm = new BadTokenSecretManager(); BadTokenSecretManager sm = new BadTokenSecretManager();
final Server server = new RPC.Builder(conf) final Server server = setupTestServer(conf, 5, sm);
.setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
.setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
.setSecretManager(sm).build();
boolean succeeded = false; boolean succeeded = false;
try { try {
doDigestRpc(server, sm); doDigestRpc(server, sm);
} catch (RemoteException e) { } catch (ServiceException e) {
LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); assertTrue(e.getCause() instanceof RemoteException);
assertEquals(ERROR_MESSAGE, e.getLocalizedMessage()); RemoteException re = (RemoteException) e.getCause();
assertTrue(e.unwrapRemoteException() instanceof InvalidToken); LOG.info("LOGGING MESSAGE: " + re.getLocalizedMessage());
assertEquals(ERROR_MESSAGE, re.getLocalizedMessage());
assertTrue(re.unwrapRemoteException() instanceof InvalidToken);
succeeded = true; succeeded = true;
} }
assertTrue(succeeded); assertTrue(succeeded);
} }
private void doDigestRpc(Server server, TestTokenSecretManager sm private void doDigestRpc(Server server, TestTokenSecretManager sm)
) throws Exception { throws Exception {
server.start();
final UserGroupInformation current = UserGroupInformation.getCurrentUser(); final UserGroupInformation current = UserGroupInformation.getCurrentUser();
final InetSocketAddress addr = NetUtils.getConnectAddress(server); addr = NetUtils.getConnectAddress(server);
TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
.getUserName())); .getUserName()));
Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, sm);
sm);
SecurityUtil.setTokenService(token, addr); SecurityUtil.setTokenService(token, addr);
current.addToken(token); current.addToken(token);
TestSaslProtocol proxy = null; TestRpcService proxy = null;
try { try {
proxy = RPC.getProxy(TestSaslProtocol.class, proxy = getClient(addr, conf);
TestSaslProtocol.versionID, addr, conf); AuthMethod authMethod = convert(
AuthMethod authMethod = proxy.getAuthMethod(); proxy.getAuthMethod(null, newEmptyRequest()));
assertEquals(TOKEN, authMethod); assertEquals(TOKEN, authMethod);
//QOP must be auth //QOP must be auth
assertEquals(expectedQop.saslQop, assertEquals(expectedQop.saslQop,
RPC.getConnectionIdForProxy(proxy).getSaslQop()); RPC.getConnectionIdForProxy(proxy).getSaslQop());
proxy.ping(); proxy.ping(null, newEmptyRequest());
} finally { } finally {
server.stop(); stop(server, proxy);
if (proxy != null) {
RPC.stopProxy(proxy);
} }
} }
}
static ConnectionId getConnectionId(Configuration conf) throws IOException {
return ConnectionId.getConnectionId(new InetSocketAddress(0),
TestSaslProtocol.class, null, 0, null, conf);
}
@Test @Test
public void testPingInterval() throws Exception { public void testPingInterval() throws Exception {
@ -466,29 +286,26 @@ public class TestSaslRPC {
// set doPing to true // set doPing to true
newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true); newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true);
ConnectionId remoteId = getConnectionId(newConf); ConnectionId remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0),
TestRpcService.class, null, 0, null, newConf);
assertEquals(CommonConfigurationKeys.IPC_PING_INTERVAL_DEFAULT, assertEquals(CommonConfigurationKeys.IPC_PING_INTERVAL_DEFAULT,
remoteId.getPingInterval()); remoteId.getPingInterval());
// set doPing to false // set doPing to false
newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, false); newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, false);
remoteId = getConnectionId(newConf); remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0),
TestRpcService.class, null, 0, null, newConf);
assertEquals(0, remoteId.getPingInterval()); assertEquals(0, remoteId.getPingInterval());
} }
@Test @Test
public void testPerConnectionConf() throws Exception { public void testPerConnectionConf() throws Exception {
TestTokenSecretManager sm = new TestTokenSecretManager(); TestTokenSecretManager sm = new TestTokenSecretManager();
final Server server = new RPC.Builder(conf) final Server server = setupTestServer(conf, 5, sm);
.setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl())
.setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
.setSecretManager(sm).build();
server.start();
final UserGroupInformation current = UserGroupInformation.getCurrentUser(); final UserGroupInformation current = UserGroupInformation.getCurrentUser();
final InetSocketAddress addr = NetUtils.getConnectAddress(server); final InetSocketAddress addr = NetUtils.getConnectAddress(server);
TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
.getUserName())); .getUserName()));
Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, Token<TestTokenIdentifier> token = new Token<>(tokenId, sm);
sm);
SecurityUtil.setTokenService(token, addr); SecurityUtil.setTokenService(token, addr);
current.addToken(token); current.addToken(token);
@ -497,28 +314,25 @@ public class TestSaslRPC {
HADOOP_RPC_SOCKET_FACTORY_CLASS_DEFAULT_KEY, ""); HADOOP_RPC_SOCKET_FACTORY_CLASS_DEFAULT_KEY, "");
Client client = null; Client client = null;
TestSaslProtocol proxy1 = null; TestRpcService proxy1 = null;
TestSaslProtocol proxy2 = null; TestRpcService proxy2 = null;
TestSaslProtocol proxy3 = null; TestRpcService proxy3 = null;
int timeouts[] = {111222, 3333333}; int timeouts[] = {111222, 3333333};
try { try {
newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[0]); newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[0]);
proxy1 = RPC.getProxy(TestSaslProtocol.class, proxy1 = getClient(addr, newConf);
TestSaslProtocol.versionID, addr, newConf); proxy1.getAuthMethod(null, newEmptyRequest());
proxy1.getAuthMethod(); client = ProtobufRpcEngine.getClient(newConf);
client = WritableRpcEngine.getClient(newConf);
Set<ConnectionId> conns = client.getConnectionIds(); Set<ConnectionId> conns = client.getConnectionIds();
assertEquals("number of connections in cache is wrong", 1, conns.size()); assertEquals("number of connections in cache is wrong", 1, conns.size());
// same conf, connection should be re-used // same conf, connection should be re-used
proxy2 = RPC.getProxy(TestSaslProtocol.class, proxy2 = getClient(addr, newConf);
TestSaslProtocol.versionID, addr, newConf); proxy2.getAuthMethod(null, newEmptyRequest());
proxy2.getAuthMethod();
assertEquals("number of connections in cache is wrong", 1, conns.size()); assertEquals("number of connections in cache is wrong", 1, conns.size());
// different conf, new connection should be set up // different conf, new connection should be set up
newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[1]); newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[1]);
proxy3 = RPC.getProxy(TestSaslProtocol.class, proxy3 = getClient(addr, newConf);
TestSaslProtocol.versionID, addr, newConf); proxy3.getAuthMethod(null, newEmptyRequest());
proxy3.getAuthMethod();
assertEquals("number of connections in cache is wrong", 2, conns.size()); assertEquals("number of connections in cache is wrong", 2, conns.size());
// now verify the proxies have the correct connection ids and timeouts // now verify the proxies have the correct connection ids and timeouts
ConnectionId[] connsArray = { ConnectionId[] connsArray = {
@ -551,24 +365,14 @@ public class TestSaslRPC {
UserGroupInformation current = UserGroupInformation.getCurrentUser(); UserGroupInformation current = UserGroupInformation.getCurrentUser();
System.out.println("UGI: " + current); System.out.println("UGI: " + current);
Server server = new RPC.Builder(newConf) Server server = setupTestServer(newConf, 5);
.setProtocol(TestSaslProtocol.class).setInstance(new TestSaslImpl()) TestRpcService proxy = null;
.setBindAddress(ADDRESS).setPort(0).setNumHandlers(5).setVerbose(true)
.build();
TestSaslProtocol proxy = null;
server.start();
InetSocketAddress addr = NetUtils.getConnectAddress(server);
try { try {
proxy = RPC.getProxy(TestSaslProtocol.class, proxy = getClient(addr, newConf);
TestSaslProtocol.versionID, addr, newConf); proxy.ping(null, newEmptyRequest());
proxy.ping();
} finally { } finally {
server.stop(); stop(server, proxy);
if (proxy != null) {
RPC.stopProxy(proxy);
}
} }
System.out.println("Test is successful."); System.out.println("Test is successful.");
} }
@ -887,14 +691,7 @@ public class TestSaslRPC {
UserGroupInformation.setConfiguration(conf); UserGroupInformation.setConfiguration(conf);
TestTokenSecretManager sm = new TestTokenSecretManager(); TestTokenSecretManager sm = new TestTokenSecretManager();
Server server = new RPC.Builder(conf) Server server = setupTestServer(conf, 1, sm);
.setProtocol(TestSaslProtocol.class)
.setInstance(new TestSaslImpl()).setBindAddress(ADDRESS).setPort(0)
.setNumHandlers(1) // prevents ordering issues when unblocking calls.
.setVerbose(true)
.setSecretManager(sm)
.build();
server.start();
try { try {
final InetSocketAddress addr = NetUtils.getConnectAddress(server); final InetSocketAddress addr = NetUtils.getConnectAddress(server);
final UserGroupInformation clientUgi = final UserGroupInformation clientUgi =
@ -903,14 +700,13 @@ public class TestSaslRPC {
TestTokenIdentifier tokenId = new TestTokenIdentifier( TestTokenIdentifier tokenId = new TestTokenIdentifier(
new Text(clientUgi.getUserName())); new Text(clientUgi.getUserName()));
Token<?> token = new Token<TestTokenIdentifier>(tokenId, sm); Token<?> token = new Token<>(tokenId, sm);
SecurityUtil.setTokenService(token, addr); SecurityUtil.setTokenService(token, addr);
clientUgi.addToken(token); clientUgi.addToken(token);
clientUgi.doAs(new PrivilegedExceptionAction<Void>() { clientUgi.doAs(new PrivilegedExceptionAction<Void>() {
@Override @Override
public Void run() throws Exception { public Void run() throws Exception {
final TestSaslProtocol proxy = RPC.getProxy(TestSaslProtocol.class, final TestRpcService proxy = getClient(addr, conf);
TestSaslProtocol.versionID, addr, conf);
final ExecutorService executor = Executors.newCachedThreadPool(); final ExecutorService executor = Executors.newCachedThreadPool();
final AtomicInteger count = new AtomicInteger(); final AtomicInteger count = new AtomicInteger();
try { try {
@ -922,7 +718,8 @@ public class TestSaslRPC {
@Override @Override
public Void call() throws Exception { public Void call() throws Exception {
String expect = "future"+count.getAndIncrement(); String expect = "future"+count.getAndIncrement();
String answer = proxy.echoPostponed(expect); String answer = convert(proxy.echoPostponed(null,
newEchoRequest(expect)));
assertEquals(expect, answer); assertEquals(expect, answer);
return null; return null;
} }
@ -939,7 +736,7 @@ public class TestSaslRPC {
// only 1 handler ensures that the prior calls are already // only 1 handler ensures that the prior calls are already
// postponed. 1 handler also ensures that this call will // postponed. 1 handler also ensures that this call will
// timeout if the postponing doesn't work (ie. free up handler) // timeout if the postponing doesn't work (ie. free up handler)
proxy.sendPostponed(); proxy.sendPostponed(null, newEmptyRequest());
for (int i=0; i < futures.length; i++) { for (int i=0; i < futures.length; i++) {
LOG.info("waiting for future"+i); LOG.info("waiting for future"+i);
futures[i].get(); futures[i].get();
@ -1009,14 +806,7 @@ public class TestSaslRPC {
Server server = serverUgi.doAs(new PrivilegedExceptionAction<Server>() { Server server = serverUgi.doAs(new PrivilegedExceptionAction<Server>() {
@Override @Override
public Server run() throws IOException { public Server run() throws IOException {
Server server = new RPC.Builder(serverConf) return setupTestServer(serverConf, 5, serverSm);
.setProtocol(TestSaslProtocol.class)
.setInstance(new TestSaslImpl()).setBindAddress(ADDRESS).setPort(0)
.setNumHandlers(5).setVerbose(true)
.setSecretManager(serverSm)
.build();
server.start();
return server;
} }
}); });
@ -1038,17 +828,17 @@ public class TestSaslRPC {
Token<TestTokenIdentifier> token = null; Token<TestTokenIdentifier> token = null;
switch (tokenType) { switch (tokenType) {
case VALID: case VALID:
token = new Token<TestTokenIdentifier>(tokenId, sm); token = new Token<>(tokenId, sm);
SecurityUtil.setTokenService(token, addr); SecurityUtil.setTokenService(token, addr);
break; break;
case INVALID: case INVALID:
token = new Token<TestTokenIdentifier>( token = new Token<>(
tokenId.getBytes(), "bad-password!".getBytes(), tokenId.getBytes(), "bad-password!".getBytes(),
tokenId.getKind(), null); tokenId.getKind(), null);
SecurityUtil.setTokenService(token, addr); SecurityUtil.setTokenService(token, addr);
break; break;
case OTHER: case OTHER:
token = new Token<TestTokenIdentifier>(); token = new Token<>();
break; break;
case NONE: // won't get here case NONE: // won't get here
} }
@ -1060,19 +850,28 @@ public class TestSaslRPC {
return clientUgi.doAs(new PrivilegedExceptionAction<String>() { return clientUgi.doAs(new PrivilegedExceptionAction<String>() {
@Override @Override
public String run() throws IOException { public String run() throws IOException {
TestSaslProtocol proxy = null; TestRpcService proxy = null;
try { try {
proxy = RPC.getProxy(TestSaslProtocol.class, proxy = getClient(addr, clientConf);
TestSaslProtocol.versionID, addr, clientConf);
proxy.ping(); proxy.ping(null, newEmptyRequest());
// make sure the other side thinks we are who we said we are!!! // make sure the other side thinks we are who we said we are!!!
assertEquals(clientUgi.getUserName(), proxy.getAuthUser()); assertEquals(clientUgi.getUserName(),
AuthMethod authMethod = proxy.getAuthMethod(); convert(proxy.getAuthUser(null, newEmptyRequest())));
AuthMethod authMethod =
convert(proxy.getAuthMethod(null, newEmptyRequest()));
// verify sasl completed with correct QOP // verify sasl completed with correct QOP
assertEquals((authMethod != SIMPLE) ? expectedQop.saslQop : null, assertEquals((authMethod != SIMPLE) ? expectedQop.saslQop : null,
RPC.getConnectionIdForProxy(proxy).getSaslQop()); RPC.getConnectionIdForProxy(proxy).getSaslQop());
return authMethod.toString(); return authMethod.toString();
} catch (ServiceException se) {
if (se.getCause() instanceof RemoteException) {
throw (RemoteException) se.getCause();
} else if (se.getCause() instanceof IOException) {
throw (IOException) se.getCause();
} else {
throw new RuntimeException(se.getCause());
}
} finally { } finally {
if (proxy != null) { if (proxy != null) {
RPC.stopProxy(proxy); RPC.stopProxy(proxy);

View File

@ -41,9 +41,9 @@ import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenInfo; import org.apache.hadoop.security.token.TokenInfo;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.apache.hadoop.ipc.TestSaslRPC.TestTokenSecretManager; import org.apache.hadoop.ipc.TestRpcBase.TestTokenSecretManager;
import org.apache.hadoop.ipc.TestSaslRPC.TestTokenIdentifier; import org.apache.hadoop.ipc.TestRpcBase.TestTokenIdentifier;
import org.apache.hadoop.ipc.TestSaslRPC.TestTokenSelector; import org.apache.hadoop.ipc.TestRpcBase.TestTokenSelector;
import org.apache.commons.logging.*; import org.apache.commons.logging.*;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;

View File

@ -82,3 +82,12 @@ message ExchangeRequestProto {
message ExchangeResponseProto { message ExchangeResponseProto {
repeated int32 values = 1; repeated int32 values = 1;
} }
message AuthMethodResponseProto {
required int32 code = 1;
required string mechanismName = 2;
}
message AuthUserResponseProto {
required string authUser = 1;
}

View File

@ -39,6 +39,10 @@ service TestProtobufRpcProto {
rpc testServerGet(EmptyRequestProto) returns (EmptyResponseProto); rpc testServerGet(EmptyRequestProto) returns (EmptyResponseProto);
rpc exchange(ExchangeRequestProto) returns (ExchangeResponseProto); rpc exchange(ExchangeRequestProto) returns (ExchangeResponseProto);
rpc sleep(SleepRequestProto) returns (EmptyResponseProto); rpc sleep(SleepRequestProto) returns (EmptyResponseProto);
rpc getAuthMethod(EmptyRequestProto) returns (AuthMethodResponseProto);
rpc getAuthUser(EmptyRequestProto) returns (AuthUserResponseProto);
rpc echoPostponed(EchoRequestProto) returns (EchoResponseProto);
rpc sendPostponed(EmptyRequestProto) returns (EmptyResponseProto);
} }
service TestProtobufRpc2Proto { service TestProtobufRpc2Proto {

View File

@ -11,5 +11,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
org.apache.hadoop.ipc.TestSaslRPC$TestTokenIdentifier org.apache.hadoop.ipc.TestRpcBase$TestTokenIdentifier
org.apache.hadoop.security.token.delegation.TestDelegationToken$TestDelegationTokenIdentifier org.apache.hadoop.security.token.delegation.TestDelegationToken$TestDelegationTokenIdentifier