HDFS-16332. Handle invalid token exception in sasl handshake (#3677)

Signed-off-by: Akira Ajisaka <aajisaka@apache.org>
Authored by bitterfox on 2021-12-03 23:30:13 +09:00; committed via GitHub.
parent 0cb6c28d19
commit dd6b987c93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 259 additions and 36 deletions

View File

@ -34,6 +34,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import javax.security.sasl.Sasl;
import org.apache.commons.codec.binary.Base64;
@ -52,6 +53,7 @@
import org.apache.hadoop.hdfs.protocol.proto.DataTransferProtos.HandshakeSecretProto;
import org.apache.hadoop.hdfs.protocol.proto.HdfsProtos.CipherOptionProto;
import org.apache.hadoop.hdfs.protocolPB.PBHelperClient;
import org.apache.hadoop.hdfs.security.token.block.InvalidBlockTokenException;
import org.apache.hadoop.security.SaslPropertiesResolver;
import org.apache.hadoop.security.SaslRpcServer.QualityOfProtection;
import org.slf4j.Logger;
@ -204,6 +206,26 @@ public static SaslPropertiesResolver getSaslPropertiesResolver(
return resolver;
}
private static <T> T readSaslMessage(InputStream in,
    Function<DataTransferEncryptorMessageProto, ? extends T> handler) throws IOException {
  // Read one length-prefixed SASL negotiation message from the stream,
  // translate error statuses into exceptions, and on SUCCESS hand the
  // message to the caller-supplied handler.
  DataTransferEncryptorMessageProto proto =
      DataTransferEncryptorMessageProto.parseFrom(vintPrefixed(in));
  DataTransferEncryptorStatus status = proto.getStatus();
  if (status == DataTransferEncryptorStatus.ERROR_UNKNOWN_KEY) {
    // Server does not recognize our encryption key; caller should refresh it.
    throw new InvalidEncryptionKeyException(proto.getMessage());
  } else if (status == DataTransferEncryptorStatus.ERROR) {
    // accessTokenError marks an invalid/expired block token so the client
    // can fetch a new token instead of failing generically.
    if (proto.hasAccessTokenError() && proto.getAccessTokenError()) {
      throw new InvalidBlockTokenException(proto.getMessage());
    }
    throw new IOException(proto.getMessage());
  } else if (status == DataTransferEncryptorStatus.SUCCESS) {
    return handler.apply(proto);
  } else {
    throw new IOException(
        "Unknown status: " + status + ", message: " + proto.getMessage());
  }
}
/**
* Reads a SASL negotiation message.
*
@ -212,15 +234,7 @@ public static SaslPropertiesResolver getSaslPropertiesResolver(
* @throws IOException for any error
*/
public static byte[] readSaslMessage(InputStream in) throws IOException {
  // The rendered diff left both the old inline status-check body and the new
  // delegating body in place (redeclared `proto`, unreachable return); keep
  // only the refactored form that delegates status handling — including the
  // new invalid-block-token case — to the shared readSaslMessage helper.
  return readSaslMessage(in, proto -> proto.getPayload().toByteArray());
}
/**
@ -233,13 +247,7 @@ public static byte[] readSaslMessage(InputStream in) throws IOException {
*/
public static byte[] readSaslMessageAndNegotiationCipherOptions(
InputStream in, List<CipherOption> cipherOptions) throws IOException {
DataTransferEncryptorMessageProto proto =
DataTransferEncryptorMessageProto.parseFrom(vintPrefixed(in));
if (proto.getStatus() == DataTransferEncryptorStatus.ERROR_UNKNOWN_KEY) {
throw new InvalidEncryptionKeyException(proto.getMessage());
} else if (proto.getStatus() == DataTransferEncryptorStatus.ERROR) {
throw new IOException(proto.getMessage());
} else {
return readSaslMessage(in, proto -> {
List<CipherOptionProto> optionProtos = proto.getCipherOptionList();
if (optionProtos != null) {
for (CipherOptionProto optionProto : optionProtos) {
@ -247,7 +255,7 @@ public static byte[] readSaslMessageAndNegotiationCipherOptions(
}
}
return proto.getPayload().toByteArray();
}
});
}
static class SaslMessageWithHandshake {
@ -276,13 +284,7 @@ String getBpid() {
public static SaslMessageWithHandshake readSaslMessageWithHandshakeSecret(
InputStream in) throws IOException {
DataTransferEncryptorMessageProto proto =
DataTransferEncryptorMessageProto.parseFrom(vintPrefixed(in));
if (proto.getStatus() == DataTransferEncryptorStatus.ERROR_UNKNOWN_KEY) {
throw new InvalidEncryptionKeyException(proto.getMessage());
} else if (proto.getStatus() == DataTransferEncryptorStatus.ERROR) {
throw new IOException(proto.getMessage());
} else {
return readSaslMessage(in, proto -> {
byte[] payload = proto.getPayload().toByteArray();
byte[] secret = null;
String bpid = null;
@ -292,7 +294,7 @@ public static SaslMessageWithHandshake readSaslMessageWithHandshakeSecret(
bpid = handshakeSecret.getBpid();
}
return new SaslMessageWithHandshake(payload, secret, bpid);
}
});
}
/**
@ -467,13 +469,7 @@ public static void sendSaslMessageAndNegotiationCipherOptions(
public static SaslResponseWithNegotiatedCipherOption
readSaslMessageAndNegotiatedCipherOption(InputStream in)
throws IOException {
DataTransferEncryptorMessageProto proto =
DataTransferEncryptorMessageProto.parseFrom(vintPrefixed(in));
if (proto.getStatus() == DataTransferEncryptorStatus.ERROR_UNKNOWN_KEY) {
throw new InvalidEncryptionKeyException(proto.getMessage());
} else if (proto.getStatus() == DataTransferEncryptorStatus.ERROR) {
throw new IOException(proto.getMessage());
} else {
return readSaslMessage(in, proto -> {
byte[] response = proto.getPayload().toByteArray();
List<CipherOption> options = PBHelperClient.convertCipherOptionProtos(
proto.getCipherOptionList());
@ -482,7 +478,7 @@ public static void sendSaslMessageAndNegotiationCipherOptions(
option = options.get(0);
}
return new SaslResponseWithNegotiatedCipherOption(response, option);
}
});
}
/**
@ -558,6 +554,13 @@ public static void sendSaslMessage(OutputStream out,
DataTransferEncryptorStatus status, byte[] payload, String message,
HandshakeSecretProto handshakeSecret)
throws IOException {
sendSaslMessage(out, status, payload, message, handshakeSecret, false);
}
public static void sendSaslMessage(OutputStream out,
DataTransferEncryptorStatus status, byte[] payload, String message,
HandshakeSecretProto handshakeSecret, boolean accessTokenError)
throws IOException {
DataTransferEncryptorMessageProto.Builder builder =
DataTransferEncryptorMessageProto.newBuilder();
@ -571,6 +574,9 @@ public static void sendSaslMessage(OutputStream out,
if (handshakeSecret != null) {
builder.setHandshakeSecret(handshakeSecret);
}
if (accessTokenError) {
builder.setAccessTokenError(true);
}
DataTransferEncryptorMessageProto proto = builder.build();
proto.writeDelimitedTo(out);

View File

@ -588,11 +588,11 @@ private IOStreamPair doSaslHandshake(InetAddress addr,
// the client accepts some cipher suites, but the server does not.
LOG.debug("Client accepts cipher suites {}, "
+ "but server {} does not accept any of them",
cipherSuites, addr.toString());
cipherSuites, addr);
}
} else {
LOG.debug("Client using cipher suite {} with server {}",
cipherOption.getCipherSuite().getName(), addr.toString());
cipherOption.getCipherSuite().getName(), addr);
}
}
}
@ -603,7 +603,20 @@ private IOStreamPair doSaslHandshake(InetAddress addr,
conf, cipherOption, underlyingOut, underlyingIn, false) :
sasl.createStreamPair(out, in);
} catch (IOException ioe) {
sendGenericSaslErrorMessage(out, ioe.getMessage());
String message = ioe.getMessage();
try {
sendGenericSaslErrorMessage(out, message);
} catch (Exception e) {
// If ioe is caused by error response from server, server will close peer connection.
// So sendGenericSaslErrorMessage might cause IOException due to "Broken pipe".
// We suppress IOException from sendGenericSaslErrorMessage
// and always throw `ioe` as top level.
// `ioe` can be InvalidEncryptionKeyException or InvalidBlockTokenException
// that indicates refresh key or token and are important for caller.
LOG.debug("Failed to send generic sasl error to server {} (message: {}), "
+ "suppress exception", addr, message, e);
ioe.addSuppressed(e);
}
throw ioe;
}
}

View File

@ -44,6 +44,7 @@ message DataTransferEncryptorMessageProto {
optional string message = 3;
repeated CipherOptionProto cipherOption = 4;
optional HandshakeSecretProto handshakeSecret = 5;
optional bool accessTokenError = 6;
}
message HandshakeSecretProto {

View File

@ -52,10 +52,12 @@
import org.apache.hadoop.hdfs.protocol.proto.DataTransferProtos.DataTransferEncryptorMessageProto.DataTransferEncryptorStatus;
import org.apache.hadoop.hdfs.security.token.block.BlockPoolTokenSecretManager;
import org.apache.hadoop.hdfs.security.token.block.BlockTokenIdentifier;
import org.apache.hadoop.hdfs.security.token.block.InvalidBlockTokenException;
import org.apache.hadoop.hdfs.server.datanode.DNConf;
import org.apache.hadoop.security.SaslPropertiesResolver;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.util.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -441,6 +443,14 @@ private IOStreamPair doSaslHandshake(Peer peer, OutputStream underlyingOut,
// error, the client will get a new encryption key from the NN and retry
// connecting to this DN.
sendInvalidKeySaslErrorMessage(out, ioe.getCause().getMessage());
} else if (ioe instanceof SaslException &&
ioe.getCause() != null &&
(ioe.getCause() instanceof InvalidBlockTokenException ||
ioe.getCause() instanceof SecretManager.InvalidToken)) {
// This could be because the client is long-lived and block token is expired
// The client will get new block token from the NN, upon receiving this error
// and retry connecting to this DN
sendInvalidTokenSaslErrorMessage(out, ioe.getCause().getMessage());
} else {
sendGenericSaslErrorMessage(out, ioe.getMessage());
}
@ -460,4 +470,16 @@ private static void sendInvalidKeySaslErrorMessage(DataOutputStream out,
sendSaslMessage(out, DataTransferEncryptorStatus.ERROR_UNKNOWN_KEY, null,
message);
}
/**
 * Sends a SASL negotiation message indicating an invalid token error.
 *
 * Uses the generic ERROR status but sets accessTokenError=true in the
 * message, which the client-side reader maps to InvalidBlockTokenException
 * so a long-lived client can fetch a fresh block token from the NameNode
 * and retry, instead of treating it as a fatal generic failure.
 *
 * @param out stream to receive message
 * @param message to send
 * @throws IOException for any error
 */
private static void sendInvalidTokenSaslErrorMessage(DataOutputStream out,
    String message) throws IOException {
  sendSaslMessage(out, DataTransferEncryptorStatus.ERROR, null, message, null, true);
}
}

View File

@ -0,0 +1,181 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hdfs.protocol.datatransfer.sasl;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hdfs.DFSInputStream;
import org.apache.hadoop.hdfs.HdfsConfiguration;
import org.apache.hadoop.hdfs.MiniDFSCluster;
import org.apache.hadoop.hdfs.client.HdfsClientConfigKeys.HedgedRead;
import org.apache.hadoop.hdfs.client.HdfsClientConfigKeys.Retry;
import org.apache.hadoop.hdfs.protocol.LocatedBlock;
import org.apache.hadoop.hdfs.security.token.block.SecurityTestUtil;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
/**
 * Tests that a client reading over SASL data transfer recovers when its
 * cached block tokens expire mid-stream: the DataNode reports the invalid
 * token during the handshake and the client refetches tokens and retries.
 */
public class TestSaslDataTransferExpiredBlockToken extends SaslDataTransferTestCase {
  private static final int BLOCK_SIZE = 4096;
  private static final int FILE_SIZE = 2 * BLOCK_SIZE;
  private static final Path PATH = new Path("/file1");

  // Reference content written to PATH; reads are validated against it.
  private final byte[] rawData = new byte[FILE_SIZE];
  private MiniDFSCluster cluster;

  @Rule
  public Timeout timeout = new Timeout(60, TimeUnit.SECONDS);

  @Before
  public void before() throws Exception {
    Random r = new Random();
    r.nextBytes(rawData);

    HdfsConfiguration conf = createSecureConfig("authentication,integrity,privacy");
    cluster = new MiniDFSCluster.Builder(conf).numDataNodes(3).build();
    cluster.waitActive();
    try (FileSystem fs = cluster.getFileSystem()) {
      createFile(fs);
    }

    // Shrink the block token lifetime to 1 second so tests can wait for
    // the tokens cached by an open stream to expire.
    SecurityTestUtil.setBlockTokenLifetime(
        cluster.getNameNode().getNamesystem().getBlockManager().getBlockTokenSecretManager(),
        1000L);
  }

  @After
  public void shutdown() {
    if (cluster != null) {
      cluster.shutdown();
      cluster = null;
    }
  }

  private void createFile(FileSystem fs) throws IOException {
    try (FSDataOutputStream out = fs.create(PATH)) {
      out.write(rawData);
    }
  }

  // Read the whole file sequentially (exercises blockSeekTo()); returns
  // false instead of throwing so callers can assert recovery succeeded.
  private boolean checkFile1(FSDataInputStream in) {
    byte[] toRead = new byte[FILE_SIZE];
    int totalRead = 0;
    int nRead = 0;
    try {
      while ((nRead = in.read(toRead, totalRead, toRead.length - totalRead)) > 0) {
        totalRead += nRead;
      }
    } catch (IOException e) {
      return false;
    }
    assertEquals("Cannot read file.", toRead.length, totalRead);
    return checkFile(toRead);
  }

  // Positional read of the whole file (exercises fetchBlockByteRange() /
  // hedgedFetchBlockByteRange()).
  private boolean checkFile2(FSDataInputStream in) {
    byte[] toRead = new byte[FILE_SIZE];
    try {
      assertEquals("Cannot read file", toRead.length, in.read(0, toRead, 0, toRead.length));
    } catch (IOException e) {
      return false;
    }
    return checkFile(toRead);
  }

  // Content check against the reference data (Arrays.equals covers both the
  // length comparison and the element-wise loop of the original code).
  private boolean checkFile(byte[] fileToCheck) {
    return Arrays.equals(fileToCheck, rawData);
  }

  private FileSystem newFileSystem() throws IOException {
    Configuration clientConf = new Configuration(cluster.getConfiguration(0));
    // Effectively unlimited retry window so the client retries after the
    // invalid-token error rather than giving up.
    clientConf.setInt(Retry.WINDOW_BASE_KEY, Integer.MAX_VALUE);
    return FileSystem.newInstance(cluster.getURI(), clientConf);
  }

  private FileSystem newFileSystemHedgedRead() throws IOException {
    Configuration clientConf = new Configuration(cluster.getConfiguration(0));
    clientConf.setInt(Retry.WINDOW_BASE_KEY, 3000);
    clientConf.setInt(HedgedRead.THREADPOOL_SIZE_KEY, 5);
    return FileSystem.newInstance(cluster.getURI(), clientConf);
  }

  @Test
  public void testBlockSeekToWithExpiredToken() throws Exception {
    // Sequential read via blockSeekTo(); tokens acquired at open() are
    // cached in the input stream and expire before the read.
    try (FileSystem fs = newFileSystem(); FSDataInputStream in = fs.open(PATH)) {
      waitBlockTokenExpired(in);
      assertTrue(checkFile1(in));
    }
  }

  @Test
  public void testFetchBlockByteRangeWithExpiredToken() throws Exception {
    // Positional read via fetchBlockByteRange(); cached tokens expire first.
    try (FileSystem fs = newFileSystem(); FSDataInputStream in = fs.open(PATH)) {
      waitBlockTokenExpired(in);
      assertTrue(checkFile2(in));
    }
  }

  @Test
  public void testHedgedFetchBlockByteRangeWithExpiredToken() throws Exception {
    // Positional read via hedgedFetchBlockByteRange(); cached tokens expire
    // first.
    try (FileSystem fs = newFileSystemHedgedRead(); FSDataInputStream in = fs.open(PATH)) {
      waitBlockTokenExpired(in);
      assertTrue(checkFile2(in));
    }
  }

  // Poll until every block token cached by the open stream has expired.
  private void waitBlockTokenExpired(FSDataInputStream in1) throws Exception {
    DFSInputStream innerStream = (DFSInputStream) in1.getWrappedStream();
    for (LocatedBlock block : innerStream.getAllBlocks()) {
      while (!SecurityTestUtil.isBlockTokenExpired(block.getBlockToken())) {
        Thread.sleep(100);
      }
    }
  }
}