Revert MAPREDUCE-5475 and YARN-707

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1517097 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Jason Darrell Lowe 2013-08-24 01:15:37 +00:00
parent 6f93f20515
commit c660339c09
8 changed files with 46 additions and 194 deletions

View File

@ -243,8 +243,6 @@ Release 2.1.1-beta - UNRELEASED
MAPREDUCE-5476. Changed MR AM recovery code to cleanup staging-directory MAPREDUCE-5476. Changed MR AM recovery code to cleanup staging-directory
only after unregistering from the RM. (Jian He via vinodkv) only after unregistering from the RM. (Jian He via vinodkv)
MAPREDUCE-5475. MRClientService does not verify ACLs properly (jlowe)
Release 2.1.0-beta - 2013-08-22 Release 2.1.0-beta - 2013-08-22
INCOMPATIBLE CHANGES INCOMPATIBLE CHANGES

View File

@ -28,7 +28,6 @@ import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.ipc.Server; import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.mapreduce.JobACL;
import org.apache.hadoop.mapreduce.MRJobConfig; import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.TypeConverter; import org.apache.hadoop.mapreduce.TypeConverter;
import org.apache.hadoop.mapreduce.v2.api.MRClientProtocol; import org.apache.hadoop.mapreduce.v2.api.MRClientProtocol;
@ -79,8 +78,6 @@ import org.apache.hadoop.mapreduce.v2.app.job.event.TaskEventType;
import org.apache.hadoop.mapreduce.v2.app.security.authorize.MRAMPolicyProvider; import org.apache.hadoop.mapreduce.v2.app.security.authorize.MRAMPolicyProvider;
import org.apache.hadoop.mapreduce.v2.app.webapp.AMWebApp; import org.apache.hadoop.mapreduce.v2.app.webapp.AMWebApp;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.PolicyProvider; import org.apache.hadoop.security.authorize.PolicyProvider;
import org.apache.hadoop.service.AbstractService; import org.apache.hadoop.service.AbstractService;
import org.apache.hadoop.yarn.factories.RecordFactory; import org.apache.hadoop.yarn.factories.RecordFactory;
@ -179,21 +176,15 @@ public class MRClientService extends AbstractService
} }
private Job verifyAndGetJob(JobId jobID, private Job verifyAndGetJob(JobId jobID,
JobACL accessType) throws IOException { boolean modifyAccess) throws IOException {
Job job = appContext.getJob(jobID); Job job = appContext.getJob(jobID);
UserGroupInformation ugi = UserGroupInformation.getCurrentUser();
if (!job.checkAccess(ugi, accessType)) {
throw new AccessControlException("User " + ugi.getShortUserName()
+ " cannot perform operation " + accessType.name() + " on "
+ jobID);
}
return job; return job;
} }
private Task verifyAndGetTask(TaskId taskID, private Task verifyAndGetTask(TaskId taskID,
JobACL accessType) throws IOException { boolean modifyAccess) throws IOException {
Task task = verifyAndGetJob(taskID.getJobId(), Task task = verifyAndGetJob(taskID.getJobId(),
accessType).getTask(taskID); modifyAccess).getTask(taskID);
if (task == null) { if (task == null) {
throw new IOException("Unknown Task " + taskID); throw new IOException("Unknown Task " + taskID);
} }
@ -201,9 +192,9 @@ public class MRClientService extends AbstractService
} }
private TaskAttempt verifyAndGetAttempt(TaskAttemptId attemptID, private TaskAttempt verifyAndGetAttempt(TaskAttemptId attemptID,
JobACL accessType) throws IOException { boolean modifyAccess) throws IOException {
TaskAttempt attempt = verifyAndGetTask(attemptID.getTaskId(), TaskAttempt attempt = verifyAndGetTask(attemptID.getTaskId(),
accessType).getAttempt(attemptID); modifyAccess).getAttempt(attemptID);
if (attempt == null) { if (attempt == null) {
throw new IOException("Unknown TaskAttempt " + attemptID); throw new IOException("Unknown TaskAttempt " + attemptID);
} }
@ -214,7 +205,7 @@ public class MRClientService extends AbstractService
public GetCountersResponse getCounters(GetCountersRequest request) public GetCountersResponse getCounters(GetCountersRequest request)
throws IOException { throws IOException {
JobId jobId = request.getJobId(); JobId jobId = request.getJobId();
Job job = verifyAndGetJob(jobId, JobACL.VIEW_JOB); Job job = verifyAndGetJob(jobId, false);
GetCountersResponse response = GetCountersResponse response =
recordFactory.newRecordInstance(GetCountersResponse.class); recordFactory.newRecordInstance(GetCountersResponse.class);
response.setCounters(TypeConverter.toYarn(job.getAllCounters())); response.setCounters(TypeConverter.toYarn(job.getAllCounters()));
@ -225,7 +216,7 @@ public class MRClientService extends AbstractService
public GetJobReportResponse getJobReport(GetJobReportRequest request) public GetJobReportResponse getJobReport(GetJobReportRequest request)
throws IOException { throws IOException {
JobId jobId = request.getJobId(); JobId jobId = request.getJobId();
Job job = verifyAndGetJob(jobId, JobACL.VIEW_JOB); Job job = verifyAndGetJob(jobId, false);
GetJobReportResponse response = GetJobReportResponse response =
recordFactory.newRecordInstance(GetJobReportResponse.class); recordFactory.newRecordInstance(GetJobReportResponse.class);
if (job != null) { if (job != null) {
@ -244,7 +235,7 @@ public class MRClientService extends AbstractService
GetTaskAttemptReportResponse response = GetTaskAttemptReportResponse response =
recordFactory.newRecordInstance(GetTaskAttemptReportResponse.class); recordFactory.newRecordInstance(GetTaskAttemptReportResponse.class);
response.setTaskAttemptReport( response.setTaskAttemptReport(
verifyAndGetAttempt(taskAttemptId, JobACL.VIEW_JOB).getReport()); verifyAndGetAttempt(taskAttemptId, false).getReport());
return response; return response;
} }
@ -254,8 +245,7 @@ public class MRClientService extends AbstractService
TaskId taskId = request.getTaskId(); TaskId taskId = request.getTaskId();
GetTaskReportResponse response = GetTaskReportResponse response =
recordFactory.newRecordInstance(GetTaskReportResponse.class); recordFactory.newRecordInstance(GetTaskReportResponse.class);
response.setTaskReport( response.setTaskReport(verifyAndGetTask(taskId, false).getReport());
verifyAndGetTask(taskId, JobACL.VIEW_JOB).getReport());
return response; return response;
} }
@ -266,7 +256,7 @@ public class MRClientService extends AbstractService
JobId jobId = request.getJobId(); JobId jobId = request.getJobId();
int fromEventId = request.getFromEventId(); int fromEventId = request.getFromEventId();
int maxEvents = request.getMaxEvents(); int maxEvents = request.getMaxEvents();
Job job = verifyAndGetJob(jobId, JobACL.VIEW_JOB); Job job = verifyAndGetJob(jobId, false);
GetTaskAttemptCompletionEventsResponse response = GetTaskAttemptCompletionEventsResponse response =
recordFactory.newRecordInstance(GetTaskAttemptCompletionEventsResponse.class); recordFactory.newRecordInstance(GetTaskAttemptCompletionEventsResponse.class);
@ -280,11 +270,9 @@ public class MRClientService extends AbstractService
public KillJobResponse killJob(KillJobRequest request) public KillJobResponse killJob(KillJobRequest request)
throws IOException { throws IOException {
JobId jobId = request.getJobId(); JobId jobId = request.getJobId();
UserGroupInformation callerUGI = UserGroupInformation.getCurrentUser(); String message = "Kill Job received from client " + jobId;
String message = "Kill job " + jobId + " received from " + callerUGI
+ " at " + Server.getRemoteAddress();
LOG.info(message); LOG.info(message);
verifyAndGetJob(jobId, JobACL.MODIFY_JOB); verifyAndGetJob(jobId, true);
appContext.getEventHandler().handle( appContext.getEventHandler().handle(
new JobDiagnosticsUpdateEvent(jobId, message)); new JobDiagnosticsUpdateEvent(jobId, message));
appContext.getEventHandler().handle( appContext.getEventHandler().handle(
@ -299,11 +287,9 @@ public class MRClientService extends AbstractService
public KillTaskResponse killTask(KillTaskRequest request) public KillTaskResponse killTask(KillTaskRequest request)
throws IOException { throws IOException {
TaskId taskId = request.getTaskId(); TaskId taskId = request.getTaskId();
UserGroupInformation callerUGI = UserGroupInformation.getCurrentUser(); String message = "Kill task received from client " + taskId;
String message = "Kill task " + taskId + " received from " + callerUGI
+ " at " + Server.getRemoteAddress();
LOG.info(message); LOG.info(message);
verifyAndGetTask(taskId, JobACL.MODIFY_JOB); verifyAndGetTask(taskId, true);
appContext.getEventHandler().handle( appContext.getEventHandler().handle(
new TaskEvent(taskId, TaskEventType.T_KILL)); new TaskEvent(taskId, TaskEventType.T_KILL));
KillTaskResponse response = KillTaskResponse response =
@ -316,12 +302,9 @@ public class MRClientService extends AbstractService
public KillTaskAttemptResponse killTaskAttempt( public KillTaskAttemptResponse killTaskAttempt(
KillTaskAttemptRequest request) throws IOException { KillTaskAttemptRequest request) throws IOException {
TaskAttemptId taskAttemptId = request.getTaskAttemptId(); TaskAttemptId taskAttemptId = request.getTaskAttemptId();
UserGroupInformation callerUGI = UserGroupInformation.getCurrentUser(); String message = "Kill task attempt received from client " + taskAttemptId;
String message = "Kill task attempt " + taskAttemptId
+ " received from " + callerUGI + " at "
+ Server.getRemoteAddress();
LOG.info(message); LOG.info(message);
verifyAndGetAttempt(taskAttemptId, JobACL.MODIFY_JOB); verifyAndGetAttempt(taskAttemptId, true);
appContext.getEventHandler().handle( appContext.getEventHandler().handle(
new TaskAttemptDiagnosticsUpdateEvent(taskAttemptId, message)); new TaskAttemptDiagnosticsUpdateEvent(taskAttemptId, message));
appContext.getEventHandler().handle( appContext.getEventHandler().handle(
@ -339,8 +322,8 @@ public class MRClientService extends AbstractService
GetDiagnosticsResponse response = GetDiagnosticsResponse response =
recordFactory.newRecordInstance(GetDiagnosticsResponse.class); recordFactory.newRecordInstance(GetDiagnosticsResponse.class);
response.addAllDiagnostics(verifyAndGetAttempt(taskAttemptId, response.addAllDiagnostics(
JobACL.VIEW_JOB).getDiagnostics()); verifyAndGetAttempt(taskAttemptId, false).getDiagnostics());
return response; return response;
} }
@ -349,12 +332,9 @@ public class MRClientService extends AbstractService
public FailTaskAttemptResponse failTaskAttempt( public FailTaskAttemptResponse failTaskAttempt(
FailTaskAttemptRequest request) throws IOException { FailTaskAttemptRequest request) throws IOException {
TaskAttemptId taskAttemptId = request.getTaskAttemptId(); TaskAttemptId taskAttemptId = request.getTaskAttemptId();
UserGroupInformation callerUGI = UserGroupInformation.getCurrentUser(); String message = "Fail task attempt received from client " + taskAttemptId;
String message = "Fail task attempt " + taskAttemptId
+ " received from " + callerUGI + " at "
+ Server.getRemoteAddress();
LOG.info(message); LOG.info(message);
verifyAndGetAttempt(taskAttemptId, JobACL.MODIFY_JOB); verifyAndGetAttempt(taskAttemptId, true);
appContext.getEventHandler().handle( appContext.getEventHandler().handle(
new TaskAttemptDiagnosticsUpdateEvent(taskAttemptId, message)); new TaskAttemptDiagnosticsUpdateEvent(taskAttemptId, message));
appContext.getEventHandler().handle( appContext.getEventHandler().handle(
@ -376,7 +356,7 @@ public class MRClientService extends AbstractService
GetTaskReportsResponse response = GetTaskReportsResponse response =
recordFactory.newRecordInstance(GetTaskReportsResponse.class); recordFactory.newRecordInstance(GetTaskReportsResponse.class);
Job job = verifyAndGetJob(jobId, JobACL.VIEW_JOB); Job job = verifyAndGetJob(jobId, false);
Collection<Task> tasks = job.getTasks(taskType).values(); Collection<Task> tasks = job.getTasks(taskType).values();
LOG.info("Getting task report for " + taskType + " " + jobId LOG.info("Getting task report for " + taskType + " " + jobId
+ ". Report-size will be " + tasks.size()); + ". Report-size will be " + tasks.size());

View File

@ -18,20 +18,13 @@
package org.apache.hadoop.mapreduce.v2.app; package org.apache.hadoop.mapreduce.v2.app;
import static org.junit.Assert.fail;
import java.security.PrivilegedExceptionAction;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import junit.framework.Assert; import junit.framework.Assert;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.JobACL;
import org.apache.hadoop.mapreduce.MRConfig;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.v2.api.MRClientProtocol; import org.apache.hadoop.mapreduce.v2.api.MRClientProtocol;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.FailTaskAttemptRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetCountersRequest; import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetCountersRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetDiagnosticsRequest; import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetDiagnosticsRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetJobReportRequest; import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetJobReportRequest;
@ -39,9 +32,6 @@ import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetTaskAttemptCompleti
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetTaskAttemptReportRequest; import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetTaskAttemptReportRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetTaskReportRequest; import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetTaskReportRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetTaskReportsRequest; import org.apache.hadoop.mapreduce.v2.api.protocolrecords.GetTaskReportsRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.KillJobRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.KillTaskAttemptRequest;
import org.apache.hadoop.mapreduce.v2.api.protocolrecords.KillTaskRequest;
import org.apache.hadoop.mapreduce.v2.api.records.AMInfo; import org.apache.hadoop.mapreduce.v2.api.records.AMInfo;
import org.apache.hadoop.mapreduce.v2.api.records.JobReport; import org.apache.hadoop.mapreduce.v2.api.records.JobReport;
import org.apache.hadoop.mapreduce.v2.api.records.JobState; import org.apache.hadoop.mapreduce.v2.api.records.JobState;
@ -61,8 +51,6 @@ import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEvent;
import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEventType; import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEventType;
import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent; import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent;
import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent.TaskAttemptStatus; import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent.TaskAttemptStatus;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.yarn.factories.RecordFactory; import org.apache.hadoop.yarn.factories.RecordFactory;
import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider; import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider;
import org.apache.hadoop.yarn.ipc.YarnRPC; import org.apache.hadoop.yarn.ipc.YarnRPC;
@ -181,79 +169,6 @@ public class TestMRClientService {
app.waitForState(job, JobState.SUCCEEDED); app.waitForState(job, JobState.SUCCEEDED);
} }
@Test
public void testViewAclOnlyCannotModify() throws Exception {
final MRAppWithClientService app = new MRAppWithClientService(1, 0, false);
final Configuration conf = new Configuration();
conf.setBoolean(MRConfig.MR_ACLS_ENABLED, true);
conf.set(MRJobConfig.JOB_ACL_VIEW_JOB, "viewonlyuser");
Job job = app.submit(conf);
app.waitForState(job, JobState.RUNNING);
Assert.assertEquals("Num tasks not correct", 1, job.getTasks().size());
Iterator<Task> it = job.getTasks().values().iterator();
Task task = it.next();
app.waitForState(task, TaskState.RUNNING);
TaskAttempt attempt = task.getAttempts().values().iterator().next();
app.waitForState(attempt, TaskAttemptState.RUNNING);
UserGroupInformation viewOnlyUser =
UserGroupInformation.createUserForTesting(
"viewonlyuser", new String[] {});
Assert.assertTrue("viewonlyuser cannot view job",
job.checkAccess(viewOnlyUser, JobACL.VIEW_JOB));
Assert.assertFalse("viewonlyuser can modify job",
job.checkAccess(viewOnlyUser, JobACL.MODIFY_JOB));
MRClientProtocol client = viewOnlyUser.doAs(
new PrivilegedExceptionAction<MRClientProtocol>() {
@Override
public MRClientProtocol run() throws Exception {
YarnRPC rpc = YarnRPC.create(conf);
return (MRClientProtocol) rpc.getProxy(MRClientProtocol.class,
app.clientService.getBindAddress(), conf);
}
});
KillJobRequest killJobRequest = recordFactory.newRecordInstance(
KillJobRequest.class);
killJobRequest.setJobId(app.getJobId());
try {
client.killJob(killJobRequest);
fail("viewonlyuser killed job");
} catch (AccessControlException e) {
// pass
}
KillTaskRequest killTaskRequest = recordFactory.newRecordInstance(
KillTaskRequest.class);
killTaskRequest.setTaskId(task.getID());
try {
client.killTask(killTaskRequest);
fail("viewonlyuser killed task");
} catch (AccessControlException e) {
// pass
}
KillTaskAttemptRequest killTaskAttemptRequest =
recordFactory.newRecordInstance(KillTaskAttemptRequest.class);
killTaskAttemptRequest.setTaskAttemptId(attempt.getID());
try {
client.killTaskAttempt(killTaskAttemptRequest);
fail("viewonlyuser killed task attempt");
} catch (AccessControlException e) {
// pass
}
FailTaskAttemptRequest failTaskAttemptRequest =
recordFactory.newRecordInstance(FailTaskAttemptRequest.class);
failTaskAttemptRequest.setTaskAttemptId(attempt.getID());
try {
client.failTaskAttempt(failTaskAttemptRequest);
fail("viewonlyuser killed task attempt");
} catch (AccessControlException e) {
// pass
}
}
private void verifyJobReport(JobReport jr) { private void verifyJobReport(JobReport jr) {
Assert.assertNotNull("JobReport is null", jr); Assert.assertNotNull("JobReport is null", jr);
List<AMInfo> amInfos = jr.getAMInfos(); List<AMInfo> amInfos = jr.getAMInfos();

View File

@ -45,8 +45,6 @@ Release 2.1.1-beta - UNRELEASED
YARN-589. Expose a REST API for monitoring the fair scheduler (Sandy Ryza). YARN-589. Expose a REST API for monitoring the fair scheduler (Sandy Ryza).
YARN-707. Add user info in the YARN ClientToken (vinodkv via jlowe)
OPTIMIZATIONS OPTIMIZATIONS
BUG FIXES BUG FIXES

View File

@ -39,7 +39,6 @@ public class ClientToAMTokenIdentifier extends TokenIdentifier {
public static final Text KIND_NAME = new Text("YARN_CLIENT_TOKEN"); public static final Text KIND_NAME = new Text("YARN_CLIENT_TOKEN");
private ApplicationAttemptId applicationAttemptId; private ApplicationAttemptId applicationAttemptId;
private Text applicationSubmitter = new Text();
// TODO: Add more information in the tokenID such that it is not // TODO: Add more information in the tokenID such that it is not
// transferrable, more secure etc. // transferrable, more secure etc.
@ -47,27 +46,21 @@ public class ClientToAMTokenIdentifier extends TokenIdentifier {
public ClientToAMTokenIdentifier() { public ClientToAMTokenIdentifier() {
} }
public ClientToAMTokenIdentifier(ApplicationAttemptId id, String appSubmitter) { public ClientToAMTokenIdentifier(ApplicationAttemptId id) {
this(); this();
this.applicationAttemptId = id; this.applicationAttemptId = id;
this.applicationSubmitter = new Text(appSubmitter);
} }
public ApplicationAttemptId getApplicationAttemptID() { public ApplicationAttemptId getApplicationAttemptID() {
return this.applicationAttemptId; return this.applicationAttemptId;
} }
public String getApplicationSubmitter() {
return this.applicationSubmitter.toString();
}
@Override @Override
public void write(DataOutput out) throws IOException { public void write(DataOutput out) throws IOException {
out.writeLong(this.applicationAttemptId.getApplicationId() out.writeLong(this.applicationAttemptId.getApplicationId()
.getClusterTimestamp()); .getClusterTimestamp());
out.writeInt(this.applicationAttemptId.getApplicationId().getId()); out.writeInt(this.applicationAttemptId.getApplicationId().getId());
out.writeInt(this.applicationAttemptId.getAttemptId()); out.writeInt(this.applicationAttemptId.getAttemptId());
this.applicationSubmitter.write(out);
} }
@Override @Override
@ -75,7 +68,6 @@ public class ClientToAMTokenIdentifier extends TokenIdentifier {
this.applicationAttemptId = this.applicationAttemptId =
ApplicationAttemptId.newInstance( ApplicationAttemptId.newInstance(
ApplicationId.newInstance(in.readLong(), in.readInt()), in.readInt()); ApplicationId.newInstance(in.readLong(), in.readInt()), in.readInt());
this.applicationSubmitter.readFields(in);
} }
@Override @Override
@ -85,11 +77,10 @@ public class ClientToAMTokenIdentifier extends TokenIdentifier {
@Override @Override
public UserGroupInformation getUser() { public UserGroupInformation getUser() {
if (this.applicationSubmitter == null) { if (this.applicationAttemptId == null) {
return null; return null;
} }
return UserGroupInformation.createRemoteUser(this.applicationSubmitter return UserGroupInformation.createRemoteUser(this.applicationAttemptId.toString());
.toString());
} }
@InterfaceAudience.Private @InterfaceAudience.Private

View File

@ -722,7 +722,7 @@ public class RMAppAttemptImpl implements RMAppAttempt, Recoverable {
// create clientToAMToken // create clientToAMToken
appAttempt.clientToAMToken = appAttempt.clientToAMToken =
new Token<ClientToAMTokenIdentifier>(new ClientToAMTokenIdentifier( new Token<ClientToAMTokenIdentifier>(new ClientToAMTokenIdentifier(
appAttempt.applicationAttemptId, appAttempt.user), appAttempt.applicationAttemptId),
appAttempt.rmContext.getClientToAMTokenSecretManager()); appAttempt.rmContext.getClientToAMTokenSecretManager());
} }

View File

@ -367,7 +367,7 @@ public class TestRMStateStore {
appToken.setService(new Text("appToken service")); appToken.setService(new Text("appToken service"));
ClientToAMTokenIdentifier clientToAMTokenId = ClientToAMTokenIdentifier clientToAMTokenId =
new ClientToAMTokenIdentifier(attemptId, "user"); new ClientToAMTokenIdentifier(attemptId);
clientToAMTokenMgr.registerApplication(attemptId); clientToAMTokenMgr.registerApplication(attemptId);
Token<ClientToAMTokenIdentifier> clientToAMToken = Token<ClientToAMTokenIdentifier> clientToAMToken =
new Token<ClientToAMTokenIdentifier>(clientToAMTokenId, clientToAMTokenMgr); new Token<ClientToAMTokenIdentifier>(clientToAMTokenId, clientToAMTokenMgr);

View File

@ -115,6 +115,7 @@ public class TestClientToAMTokens {
private final byte[] secretKey; private final byte[] secretKey;
private InetSocketAddress address; private InetSocketAddress address;
private boolean pinged = false; private boolean pinged = false;
private ClientToAMTokenSecretManager secretManager;
public CustomAM(ApplicationAttemptId appId, byte[] secretKey) { public CustomAM(ApplicationAttemptId appId, byte[] secretKey) {
super("CustomAM"); super("CustomAM");
@ -131,14 +132,12 @@ public class TestClientToAMTokens {
protected void serviceStart() throws Exception { protected void serviceStart() throws Exception {
Configuration conf = getConfig(); Configuration conf = getConfig();
secretManager = new ClientToAMTokenSecretManager(this.appAttemptId, secretKey);
Server server; Server server;
try { try {
server = server =
new RPC.Builder(conf) new RPC.Builder(conf).setProtocol(CustomProtocol.class)
.setProtocol(CustomProtocol.class) .setNumHandlers(1).setSecretManager(secretManager)
.setNumHandlers(1)
.setSecretManager(
new ClientToAMTokenSecretManager(this.appAttemptId, secretKey))
.setInstance(this).build(); .setInstance(this).build();
} catch (Exception e) { } catch (Exception e) {
throw new YarnRuntimeException(e); throw new YarnRuntimeException(e);
@ -147,10 +146,14 @@ public class TestClientToAMTokens {
this.address = NetUtils.getConnectAddress(server); this.address = NetUtils.getConnectAddress(server);
super.serviceStart(); super.serviceStart();
} }
public ClientToAMTokenSecretManager getClientToAMTokenSecretManager() {
return this.secretManager;
}
} }
@Test @Test
public void testClientToAMTokenss() throws Exception { public void testClientToAMs() throws Exception {
final Configuration conf = new Configuration(); final Configuration conf = new Configuration();
conf.set(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION, conf.set(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION,
@ -201,7 +204,7 @@ public class TestClientToAMTokens {
GetApplicationReportResponse reportResponse = GetApplicationReportResponse reportResponse =
rm.getClientRMService().getApplicationReport(request); rm.getClientRMService().getApplicationReport(request);
ApplicationReport appReport = reportResponse.getApplicationReport(); ApplicationReport appReport = reportResponse.getApplicationReport();
org.apache.hadoop.yarn.api.records.Token originalClientToAMToken = org.apache.hadoop.yarn.api.records.Token clientToAMToken =
appReport.getClientToAMToken(); appReport.getClientToAMToken();
ApplicationAttemptId appAttempt = app.getCurrentAppAttempt().getAppAttemptId(); ApplicationAttemptId appAttempt = app.getCurrentAppAttempt().getAppAttemptId();
@ -256,47 +259,17 @@ public class TestClientToAMTokens {
Assert.assertFalse(am.pinged); Assert.assertFalse(am.pinged);
} }
Token<ClientToAMTokenIdentifier> token = // Verify denial for a malicious user
ConverterUtils.convertFromYarn(originalClientToAMToken, am.address);
// Verify denial for a malicious user with tampered ID
verifyTokenWithTamperedID(conf, am, token);
// Verify denial for a malicious user with tampered user-name
verifyTokenWithTamperedUserName(conf, am, token);
// Now for an authenticated user
verifyValidToken(conf, am, token);
}
private void verifyTokenWithTamperedID(final Configuration conf,
final CustomAM am, Token<ClientToAMTokenIdentifier> token)
throws IOException {
// Malicious user, messes with appId
UserGroupInformation ugi = UserGroupInformation.createRemoteUser("me"); UserGroupInformation ugi = UserGroupInformation.createRemoteUser("me");
Token<ClientToAMTokenIdentifier> token =
ConverterUtils.convertFromYarn(clientToAMToken, am.address);
// Malicious user, messes with appId
ClientToAMTokenIdentifier maliciousID = ClientToAMTokenIdentifier maliciousID =
new ClientToAMTokenIdentifier(BuilderUtils.newApplicationAttemptId( new ClientToAMTokenIdentifier(BuilderUtils.newApplicationAttemptId(
BuilderUtils.newApplicationId(am.appAttemptId.getApplicationId() BuilderUtils.newApplicationId(app.getApplicationId()
.getClusterTimestamp(), 42), 43), UserGroupInformation .getClusterTimestamp(), 42), 43));
.getCurrentUser().getShortUserName());
verifyTamperedToken(conf, am, token, ugi, maliciousID);
}
private void verifyTokenWithTamperedUserName(final Configuration conf,
final CustomAM am, Token<ClientToAMTokenIdentifier> token)
throws IOException {
// Malicious user, messes with appId
UserGroupInformation ugi = UserGroupInformation.createRemoteUser("me");
ClientToAMTokenIdentifier maliciousID =
new ClientToAMTokenIdentifier(am.appAttemptId, "evilOrc");
verifyTamperedToken(conf, am, token, ugi, maliciousID);
}
private void verifyTamperedToken(final Configuration conf, final CustomAM am,
Token<ClientToAMTokenIdentifier> token, UserGroupInformation ugi,
ClientToAMTokenIdentifier maliciousID) {
Token<ClientToAMTokenIdentifier> maliciousToken = Token<ClientToAMTokenIdentifier> maliciousToken =
new Token<ClientToAMTokenIdentifier>(maliciousID.getBytes(), new Token<ClientToAMTokenIdentifier>(maliciousID.getBytes(),
token.getPassword(), token.getKind(), token.getPassword(), token.getKind(),
@ -336,12 +309,8 @@ public class TestClientToAMTokens {
+ "Mismatched response.")); + "Mismatched response."));
Assert.assertFalse(am.pinged); Assert.assertFalse(am.pinged);
} }
}
private void verifyValidToken(final Configuration conf, final CustomAM am, // Now for an authenticated user
Token<ClientToAMTokenIdentifier> token) throws IOException,
InterruptedException {
UserGroupInformation ugi;
ugi = UserGroupInformation.createRemoteUser("me"); ugi = UserGroupInformation.createRemoteUser("me");
ugi.addToken(token); ugi.addToken(token);
@ -357,4 +326,5 @@ public class TestClientToAMTokens {
} }
}); });
} }
} }