YARN-8310. Handle old NMTokenIdentifier, AMRMTokenIdentifier, and ContainerTokenIdentifier formats. Contributed by Robert Kanter.

(cherry picked from commit 3e5f7ea986)
This commit is contained in:
Miklos Szegedi 2018-05-22 18:10:33 -07:00
parent f49e697cc8
commit 1483c90379
5 changed files with 278 additions and 26 deletions

View File

@ -513,4 +513,24 @@ public class IOUtils {
throw exception;
}
}
/**
* Reads a DataInput until EOF and returns a byte array. Make sure not to
* pass in an infinite DataInput or this will never return.
*
* @param in A DataInput
* @return a byte array containing the data from the DataInput
* @throws IOException on I/O error, other than EOF
*/
public static byte[] readFullyToByteArray(DataInput in) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
while (true) {
baos.write(in.readByte());
}
} catch (EOFException eof) {
// finished reading, do nothing
}
return baos.toByteArray();
}
}

View File

@ -18,20 +18,26 @@
package org.apache.hadoop.yarn.security;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.classification.InterfaceStability.Evolving;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.impl.pb.ApplicationAttemptIdPBImpl;
import org.apache.hadoop.yarn.proto.YarnSecurityTokenProtos.AMRMTokenIdentifierProto;
@ -45,6 +51,8 @@ import com.google.protobuf.TextFormat;
@Evolving
public class AMRMTokenIdentifier extends TokenIdentifier {
private static final Log LOG = LogFactory.getLog(AMRMTokenIdentifier.class);
public static final Text KIND_NAME = new Text("YARN_AM_RM_TOKEN");
private AMRMTokenIdentifierProto proto;
@ -78,7 +86,30 @@ public class AMRMTokenIdentifier extends TokenIdentifier {
@Override
public void readFields(DataInput in) throws IOException {
proto = AMRMTokenIdentifierProto.parseFrom((DataInputStream)in);
byte[] data = IOUtils.readFullyToByteArray(in);
try {
proto = AMRMTokenIdentifierProto.parseFrom(data);
} catch (InvalidProtocolBufferException e) {
LOG.warn("Recovering old formatted token");
readFieldsInOldFormat(
new DataInputStream(new ByteArrayInputStream(data)));
}
}
private void readFieldsInOldFormat(DataInputStream in) throws IOException {
AMRMTokenIdentifierProto.Builder builder =
AMRMTokenIdentifierProto.newBuilder();
long clusterTimeStamp = in.readLong();
int appId = in.readInt();
int attemptId = in.readInt();
ApplicationId applicationId =
ApplicationId.newInstance(clusterTimeStamp, appId);
ApplicationAttemptId appAttemptId =
ApplicationAttemptId.newInstance(applicationId, attemptId);
builder.setAppAttemptId(
((ApplicationAttemptIdPBImpl)appAttemptId).getProto());
builder.setKeyId(in.readInt());
proto = builder.build();
}
@Override

View File

@ -1,40 +1,45 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.security;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.EOFException;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.classification.InterfaceStability.Evolving;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ExecutionType;
import org.apache.hadoop.yarn.api.records.LogAggregationContext;
@ -48,6 +53,7 @@ import org.apache.hadoop.yarn.api.records.impl.pb.ResourcePBImpl;
import org.apache.hadoop.yarn.nodelabels.CommonNodeLabelsManager;
import org.apache.hadoop.yarn.proto.YarnProtos.ContainerTypeProto;
import org.apache.hadoop.yarn.proto.YarnProtos.ExecutionTypeProto;
import org.apache.hadoop.yarn.proto.YarnProtos.LogAggregationContextProto;
import org.apache.hadoop.yarn.proto.YarnSecurityTokenProtos.ContainerTokenIdentifierProto;
import org.apache.hadoop.yarn.server.api.ContainerType;
@ -325,7 +331,63 @@ public class ContainerTokenIdentifier extends TokenIdentifier {
@Override
public void readFields(DataInput in) throws IOException {
proto = ContainerTokenIdentifierProto.parseFrom((DataInputStream)in);
byte[] data = IOUtils.readFullyToByteArray(in);
try {
proto = ContainerTokenIdentifierProto.parseFrom(data);
} catch (InvalidProtocolBufferException e) {
LOG.warn("Recovering old formatted token");
readFieldsInOldFormat(
new DataInputStream(new ByteArrayInputStream(data)));
}
}
private void readFieldsInOldFormat(DataInputStream in) throws IOException {
ContainerTokenIdentifierProto.Builder builder =
ContainerTokenIdentifierProto.newBuilder();
builder.setNodeLabelExpression(CommonNodeLabelsManager.NO_LABEL);
builder.setContainerType(ProtoUtils.convertToProtoFormat(
ContainerType.TASK));
builder.setExecutionType(ProtoUtils.convertToProtoFormat(
ExecutionType.GUARANTEED));
builder.setAllocationRequestId(-1);
builder.setVersion(0);
ApplicationId applicationId =
ApplicationId.newInstance(in.readLong(), in.readInt());
ApplicationAttemptId applicationAttemptId =
ApplicationAttemptId.newInstance(applicationId, in.readInt());
ContainerId containerId =
ContainerId.newContainerId(applicationAttemptId, in.readLong());
builder.setContainerId(ProtoUtils.convertToProtoFormat(containerId));
builder.setNmHostAddr(in.readUTF());
builder.setAppSubmitter(in.readUTF());
int memory = in.readInt();
int vCores = in.readInt();
Resource resource = Resource.newInstance(memory, vCores);
builder.setResource(ProtoUtils.convertToProtoFormat(resource));
builder.setExpiryTimeStamp(in.readLong());
builder.setMasterKeyId(in.readInt());
builder.setRmIdentifier(in.readLong());
Priority priority = Priority.newInstance(in.readInt());
builder.setPriority(((PriorityPBImpl)priority).getProto());
builder.setCreationTime(in.readLong());
int logAggregationSize = -1;
try {
logAggregationSize = in.readInt();
} catch (EOFException eof) {
// In the old format, there was no versioning or proper handling of new
// fields. Depending on how old, the log aggregation size and data, may
// or may not exist. To handle that, we try to read it and ignore the
// EOFException that's thrown if it's not there.
}
if (logAggregationSize != -1) {
byte[] bytes = new byte[logAggregationSize];
in.readFully(bytes);
builder.setLogAggregationContext(
LogAggregationContextProto.parseFrom(bytes));
}
proto = builder.build();
}
@Override

View File

@ -18,19 +18,23 @@
package org.apache.hadoop.yarn.security;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.classification.InterfaceStability.Evolving;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.impl.pb.ApplicationAttemptIdPBImpl;
import org.apache.hadoop.yarn.api.records.impl.pb.NodeIdPBImpl;
@ -99,7 +103,33 @@ public class NMTokenIdentifier extends TokenIdentifier {
@Override
public void readFields(DataInput in) throws IOException {
proto = NMTokenIdentifierProto.parseFrom((DataInputStream)in);
byte[] data = IOUtils.readFullyToByteArray(in);
try {
proto = NMTokenIdentifierProto.parseFrom(data);
} catch (InvalidProtocolBufferException e) {
LOG.warn("Recovering old formatted token");
readFieldsInOldFormat(
new DataInputStream(new ByteArrayInputStream(data)));
}
}
private void readFieldsInOldFormat(DataInputStream in) throws IOException {
NMTokenIdentifierProto.Builder builder =
NMTokenIdentifierProto.newBuilder();
ApplicationAttemptId appAttemptId =
ApplicationAttemptId.newInstance(
ApplicationId.newInstance(in.readLong(), in.readInt()),
in.readInt());
builder.setAppAttemptId(((ApplicationAttemptIdPBImpl)appAttemptId)
.getProto());
String[] hostAddr = in.readUTF().split(":");
NodeId nodeId = NodeId.newInstance(hostAddr[0],
Integer.parseInt(hostAddr[1]));
builder.setNodeId(((NodeIdPBImpl)nodeId).getProto());
builder.setAppSubmitter(in.readUTF());
builder.setKeyId(in.readInt());
proto = builder.build();
}
@Override

View File

@ -34,6 +34,7 @@ import org.apache.hadoop.yarn.api.records.ExecutionType;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.impl.pb.LogAggregationContextPBImpl;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.nodelabels.CommonNodeLabelsManager;
import org.apache.hadoop.yarn.proto.YarnSecurityTokenProtos.ContainerTokenIdentifierProto;
@ -49,6 +50,15 @@ public class TestYARNTokenIdentifier {
@Test
public void testNMTokenIdentifier() throws IOException {
testNMTokenIdentifier(false);
}
@Test
public void testNMTokenIdentifierOldFormat() throws IOException {
testNMTokenIdentifier(true);
}
public void testNMTokenIdentifier(boolean oldFormat) throws IOException {
ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(
ApplicationId.newInstance(1, 1), 1);
NodeId nodeId = NodeId.newInstance("host0", 0);
@ -59,8 +69,13 @@ public class TestYARNTokenIdentifier {
appAttemptId, nodeId, applicationSubmitter, masterKeyId);
NMTokenIdentifier anotherToken = new NMTokenIdentifier();
byte[] tokenContent = token.getBytes();
byte[] tokenContent;
if (oldFormat) {
tokenContent = writeInOldFormat(token);
} else {
tokenContent = token.getBytes();
}
DataInputBuffer dib = new DataInputBuffer();
dib.reset(tokenContent, tokenContent.length);
anotherToken.readFields(dib);
@ -89,6 +104,15 @@ public class TestYARNTokenIdentifier {
@Test
public void testAMRMTokenIdentifier() throws IOException {
testAMRMTokenIdentifier(false);
}
@Test
public void testAMRMTokenIdentifierOldFormat() throws IOException {
testAMRMTokenIdentifier(true);
}
public void testAMRMTokenIdentifier(boolean oldFormat) throws IOException {
ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(
ApplicationId.newInstance(1, 1), 1);
int masterKeyId = 1;
@ -96,7 +120,13 @@ public class TestYARNTokenIdentifier {
AMRMTokenIdentifier token = new AMRMTokenIdentifier(appAttemptId, masterKeyId);
AMRMTokenIdentifier anotherToken = new AMRMTokenIdentifier();
byte[] tokenContent = token.getBytes();
byte[] tokenContent;
if (oldFormat) {
tokenContent = writeInOldFormat(token);
} else {
tokenContent = token.getBytes();
}
DataInputBuffer dib = new DataInputBuffer();
dib.reset(tokenContent, tokenContent.length);
anotherToken.readFields(dib);
@ -139,7 +169,7 @@ public class TestYARNTokenIdentifier {
Assert.assertEquals("clientName from proto is not the same with original token",
anotherToken.getClientName(), clientName);
}
@Test
public void testContainerTokenIdentifierProtoMissingFields()
throws IOException {
@ -166,6 +196,17 @@ public class TestYARNTokenIdentifier {
@Test
public void testContainerTokenIdentifier() throws IOException {
testContainerTokenIdentifier(false, false);
}
@Test
public void testContainerTokenIdentifierOldFormat() throws IOException {
testContainerTokenIdentifier(true, true);
testContainerTokenIdentifier(true, false);
}
public void testContainerTokenIdentifier(boolean oldFormat,
boolean withLogAggregation) throws IOException {
ContainerId containerID = ContainerId.newContainerId(
ApplicationAttemptId.newInstance(ApplicationId.newInstance(
1, 1), 1), 1);
@ -183,8 +224,13 @@ public class TestYARNTokenIdentifier {
masterKeyId, rmIdentifier, priority, creationTime);
ContainerTokenIdentifier anotherToken = new ContainerTokenIdentifier();
byte[] tokenContent = token.getBytes();
byte[] tokenContent;
if (oldFormat) {
tokenContent = writeInOldFormat(token, withLogAggregation);
} else {
tokenContent = token.getBytes();
}
DataInputBuffer dib = new DataInputBuffer();
dib.reset(tokenContent, tokenContent.length);
anotherToken.readFields(dib);
@ -451,4 +497,67 @@ public class TestYARNTokenIdentifier {
anotherToken.getExecutionType());
}
@SuppressWarnings("deprecation")
private byte[] writeInOldFormat(ContainerTokenIdentifier token,
boolean withLogAggregation) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream out = new DataOutputStream(baos);
ApplicationAttemptId applicationAttemptId = token.getContainerID()
.getApplicationAttemptId();
ApplicationId applicationId = applicationAttemptId.getApplicationId();
out.writeLong(applicationId.getClusterTimestamp());
out.writeInt(applicationId.getId());
out.writeInt(applicationAttemptId.getAttemptId());
out.writeLong(token.getContainerID().getContainerId());
out.writeUTF(token.getNmHostAddress());
out.writeUTF(token.getApplicationSubmitter());
out.writeInt(token.getResource().getMemory());
out.writeInt(token.getResource().getVirtualCores());
out.writeLong(token.getExpiryTimeStamp());
out.writeInt(token.getMasterKeyId());
out.writeLong(token.getRMIdentifier());
out.writeInt(token.getPriority().getPriority());
out.writeLong(token.getCreationTime());
if (withLogAggregation) {
if (token.getLogAggregationContext() == null) {
out.writeInt(-1);
} else {
byte[] logAggregationContext = ((LogAggregationContextPBImpl)
token.getLogAggregationContext()).getProto().toByteArray();
out.writeInt(logAggregationContext.length);
out.write(logAggregationContext);
}
}
out.close();
return baos.toByteArray();
}
private byte[] writeInOldFormat(NMTokenIdentifier token) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream out = new DataOutputStream(baos);
ApplicationId applicationId = token.getApplicationAttemptId()
.getApplicationId();
out.writeLong(applicationId.getClusterTimestamp());
out.writeInt(applicationId.getId());
out.writeInt(token.getApplicationAttemptId().getAttemptId());
out.writeUTF(token.getNodeId().toString());
out.writeUTF(token.getApplicationSubmitter());
out.writeInt(token.getKeyId());
out.close();
return baos.toByteArray();
}
private byte[] writeInOldFormat(AMRMTokenIdentifier token)
throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream out = new DataOutputStream(baos);
ApplicationId applicationId = token.getApplicationAttemptId()
.getApplicationId();
out.writeLong(applicationId.getClusterTimestamp());
out.writeInt(applicationId.getId());
out.writeInt(token.getApplicationAttemptId().getAttemptId());
out.writeInt(token.getKeyId());
out.close();
return baos.toByteArray();
}
}