HADOOP-17346. Fair call queue is defeated by abusive service principals. Contributed by Ahmed Hussein (ahussein).

This commit is contained in:
Eric Payne 2020-11-23 21:19:07 +00:00
parent 6062978768
commit c7845dc574
10 changed files with 172 additions and 30 deletions

View File

@ -33,6 +33,7 @@ import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.security.UserGroupInformation;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -201,6 +202,19 @@ public class CallQueueManager<E extends Schedulable>
return scheduler.getPriorityLevel(e); return scheduler.getPriorityLevel(e);
} }
int getPriorityLevel(UserGroupInformation user) {
if (scheduler instanceof DecayRpcScheduler) {
return ((DecayRpcScheduler)scheduler).getPriorityLevel(user);
}
return 0;
}
void setPriorityLevel(UserGroupInformation user, int priority) {
if (scheduler instanceof DecayRpcScheduler) {
((DecayRpcScheduler)scheduler).setPriorityLevel(user, priority);
}
}
void setClientBackoffEnabled(boolean value) { void setClientBackoffEnabled(boolean value) {
clientBackOffEnabled = value; clientBackOffEnabled = value;
} }

View File

@ -51,6 +51,7 @@ import org.apache.hadoop.metrics2.lib.Interns;
import org.apache.hadoop.metrics2.util.MBeans; import org.apache.hadoop.metrics2.util.MBeans;
import org.apache.hadoop.metrics2.util.Metrics2Util.NameValuePair; import org.apache.hadoop.metrics2.util.Metrics2Util.NameValuePair;
import org.apache.hadoop.metrics2.util.Metrics2Util.TopN; import org.apache.hadoop.metrics2.util.Metrics2Util.TopN;
import org.apache.hadoop.security.UserGroupInformation;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -172,7 +173,7 @@ public class DecayRpcScheduler implements RpcScheduler,
private static final double PRECISION = 0.0001; private static final double PRECISION = 0.0001;
private MetricsProxy metricsProxy; private MetricsProxy metricsProxy;
private final CostProvider costProvider; private final CostProvider costProvider;
private final Map<String,Integer> staticPriorities = new HashMap<>();
/** /**
* This TimerTask will call decayCurrentCosts until * This TimerTask will call decayCurrentCosts until
* the scheduler has been garbage collected. * the scheduler has been garbage collected.
@ -468,7 +469,7 @@ public class DecayRpcScheduler implements RpcScheduler,
AtomicLong value = entry.getValue().get(0); AtomicLong value = entry.getValue().get(0);
long snapshot = value.get(); long snapshot = value.get();
int computedLevel = computePriorityLevel(snapshot); int computedLevel = computePriorityLevel(snapshot, id);
nextCache.put(id, computedLevel); nextCache.put(id, computedLevel);
} }
@ -515,10 +516,15 @@ public class DecayRpcScheduler implements RpcScheduler,
/** /**
* Given the cost for an identity, compute a scheduling decision. * Given the cost for an identity, compute a scheduling decision.
* *
* @param identity to compute a cost
* @param cost the cost for an identity * @param cost the cost for an identity
* @return scheduling decision from 0 to numLevels - 1 * @return scheduling decision from 0 to numLevels - 1
*/ */
private int computePriorityLevel(long cost) { private int computePriorityLevel(long cost, Object identity) {
Integer staticPriority = staticPriorities.get(identity);
if (staticPriority != null) {
return staticPriority.intValue();
}
long totalCallSnapshot = totalDecayedCallCost.get(); long totalCallSnapshot = totalDecayedCallCost.get();
double proportion = 0; double proportion = 0;
@ -558,11 +564,20 @@ public class DecayRpcScheduler implements RpcScheduler,
// Cache was no good, compute it // Cache was no good, compute it
List<AtomicLong> costList = callCosts.get(identity); List<AtomicLong> costList = callCosts.get(identity);
long currentCost = costList == null ? 0 : costList.get(0).get(); long currentCost = costList == null ? 0 : costList.get(0).get();
int priority = computePriorityLevel(currentCost); int priority = computePriorityLevel(currentCost, identity);
LOG.debug("compute priority for {} priority {}", identity, priority); LOG.debug("compute priority for {} priority {}", identity, priority);
return priority; return priority;
} }
private String getIdentity(Schedulable obj) {
String identity = this.identityProvider.makeIdentity(obj);
if (identity == null) {
// Identity provider did not handle this
identity = DECAYSCHEDULER_UNKNOWN_IDENTITY;
}
return identity;
}
/** /**
* Compute the appropriate priority for a schedulable based on past requests. * Compute the appropriate priority for a schedulable based on past requests.
* @param obj the schedulable obj to query and remember * @param obj the schedulable obj to query and remember
@ -571,15 +586,41 @@ public class DecayRpcScheduler implements RpcScheduler,
@Override @Override
public int getPriorityLevel(Schedulable obj) { public int getPriorityLevel(Schedulable obj) {
// First get the identity // First get the identity
String identity = this.identityProvider.makeIdentity(obj); String identity = getIdentity(obj);
if (identity == null) { // highest priority users may have a negative priority but their
// Identity provider did not handle this // calls will be priority 0.
identity = DECAYSCHEDULER_UNKNOWN_IDENTITY; return Math.max(0, cachedOrComputedPriorityLevel(identity));
} }
@VisibleForTesting
int getPriorityLevel(UserGroupInformation ugi) {
String identity = getIdentity(newSchedulable(ugi));
// returns true priority of the user.
return cachedOrComputedPriorityLevel(identity); return cachedOrComputedPriorityLevel(identity);
} }
@VisibleForTesting
void setPriorityLevel(UserGroupInformation ugi, int priority) {
String identity = getIdentity(newSchedulable(ugi));
priority = Math.min(numLevels - 1, priority);
LOG.info("Setting priority for user:" + identity + "=" + priority);
staticPriorities.put(identity, priority);
}
// dummy instance to conform to identity provider api.
private static Schedulable newSchedulable(UserGroupInformation ugi) {
return new Schedulable() {
@Override
public UserGroupInformation getUserGroupInformation() {
return ugi;
}
@Override
public int getPriorityLevel() {
return 0;
}
};
}
@Override @Override
public boolean shouldBackOff(Schedulable obj) { public boolean shouldBackOff(Schedulable obj) {
Boolean backOff = false; Boolean backOff = false;

View File

@ -49,6 +49,7 @@ import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.Rpc
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.SaslRpcServer; import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenIdentifier;
@ -946,7 +947,18 @@ public class RPC {
" ProtocolImpl=" + protocolImpl.getClass().getName() + " ProtocolImpl=" + protocolImpl.getClass().getName() +
" protocolClass=" + protocolClass.getName()); " protocolClass=" + protocolClass.getName());
} }
} String client = SecurityUtil.getClientPrincipal(protocolClass, getConf());
if (client != null) {
// notify the server's rpc scheduler that the protocol user has
// highest priority. the scheduler should exempt the user from
// priority calculations.
try {
setPriorityLevel(UserGroupInformation.createRemoteUser(client), -1);
} catch (Exception ex) {
LOG.warn("Failed to set scheduling priority for " + client, ex);
}
}
}
static class VerProtocolImpl { static class VerProtocolImpl {
final long version; final long version;

View File

@ -638,7 +638,22 @@ public abstract class Server {
address.getPort(), e); address.getPort(), e);
} }
} }
@VisibleForTesting
int getPriorityLevel(Schedulable e) {
return callQueue.getPriorityLevel(e);
}
@VisibleForTesting
int getPriorityLevel(UserGroupInformation ugi) {
return callQueue.getPriorityLevel(ugi);
}
@VisibleForTesting
void setPriorityLevel(UserGroupInformation ugi, int priority) {
callQueue.setPriorityLevel(ugi, priority);
}
/** /**
* Returns a handle to the rpcMetrics (required in tests) * Returns a handle to the rpcMetrics (required in tests)
* @return rpc metrics * @return rpc metrics

View File

@ -31,6 +31,6 @@ public class UserIdentityProvider implements IdentityProvider {
return null; return null;
} }
return ugi.getUserName(); return ugi.getShortUserName();
} }
} }

View File

@ -377,7 +377,25 @@ public final class SecurityUtil {
} }
return null; return null;
} }
/**
* Look up the client principal for a given protocol. It searches all known
* SecurityInfo providers.
* @param protocol the protocol class to get the information for
* @param conf configuration object
* @return client principal or null if it has no client principal defined.
*/
public static String getClientPrincipal(Class<?> protocol,
Configuration conf) {
String user = null;
KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf);
if (krbInfo != null) {
String key = krbInfo.clientPrincipal();
user = (key != null && !key.isEmpty()) ? conf.get(key) : null;
}
return user;
}
/** /**
* Look up the TokenInfo for a given protocol. It searches all known * Look up the TokenInfo for a given protocol. It searches all known
* SecurityInfo providers. * SecurityInfo providers.

View File

@ -99,22 +99,23 @@ public class ServiceAuthorizationManager {
} }
// get client principal key to verify (if available) // get client principal key to verify (if available)
KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf); String clientPrincipal = null;
String clientPrincipal = null;
if (krbInfo != null) {
String clientKey = krbInfo.clientPrincipal(); try {
if (clientKey != null && !clientKey.isEmpty()) { clientPrincipal = SecurityUtil.getClientPrincipal(protocol, conf);
try { if (clientPrincipal != null) {
clientPrincipal = SecurityUtil.getServerPrincipal(
conf.get(clientKey), addr); clientPrincipal = SecurityUtil.getServerPrincipal(clientPrincipal,
} catch (IOException e) { addr);
throw (AuthorizationException) new AuthorizationException(
"Can't figure out Kerberos principal name for connection from "
+ addr + " for user=" + user + " protocol=" + protocol)
.initCause(e);
}
} }
} catch (IOException e) {
throw (AuthorizationException) new AuthorizationException(
"Can't figure out Kerberos principal name for connection from "
+ addr + " for user=" + user + " protocol=" + protocol)
.initCause(e);
} }
if((clientPrincipal != null && !clientPrincipal.equals(user.getUserName())) || if((clientPrincipal != null && !clientPrincipal.equals(user.getUserName())) ||
acls.length != 2 || !acls[0].isUserAllowed(user) || acls[1].isUserAllowed(user)) { acls.length != 2 || !acls[0].isUserAllowed(user) || acls[1].isUserAllowed(user)) {
String cause = clientPrincipal != null ? String cause = clientPrincipal != null ?

View File

@ -42,9 +42,8 @@ import java.util.concurrent.TimeUnit;
public class TestDecayRpcScheduler { public class TestDecayRpcScheduler {
private Schedulable mockCall(String id) { private Schedulable mockCall(String id) {
Schedulable mockCall = mock(Schedulable.class); Schedulable mockCall = mock(Schedulable.class);
UserGroupInformation ugi = mock(UserGroupInformation.class); UserGroupInformation ugi = UserGroupInformation.createRemoteUser(id);
when(ugi.getUserName()).thenReturn(id);
when(mockCall.getUserGroupInformation()).thenReturn(ugi); when(mockCall.getUserGroupInformation()).thenReturn(ugi);
return mockCall; return mockCall;

View File

@ -48,6 +48,8 @@ import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.test.GenericTestUtils; import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.test.MetricsAsserts; import org.apache.hadoop.test.MetricsAsserts;
import org.apache.hadoop.test.MockitoUtil; import org.apache.hadoop.test.MockitoUtil;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.internal.util.reflection.Whitebox; import org.mockito.internal.util.reflection.Whitebox;
@ -1285,6 +1287,43 @@ public class TestRPC extends TestRpcBase {
} }
} }
@Test (timeout=30000)
public void testProtocolUserPriority() throws Exception {
final String ns = CommonConfigurationKeys.IPC_NAMESPACE + ".0";
conf.set(CLIENT_PRINCIPAL_KEY, "clientForProtocol");
Server server = null;
try {
server = setupDecayRpcSchedulerandTestServer(ns + ".");
UserGroupInformation ugi = UserGroupInformation.createRemoteUser("user");
// normal users start with priority 0.
Assert.assertEquals(0, server.getPriorityLevel(ugi));
// calls for a protocol defined client will have priority of 0.
Assert.assertEquals(0, server.getPriorityLevel(newSchedulable(ugi)));
// protocol defined client will have top priority of -1.
ugi = UserGroupInformation.createRemoteUser("clientForProtocol");
Assert.assertEquals(-1, server.getPriorityLevel(ugi));
// calls for a protocol defined client will have priority of 0.
Assert.assertEquals(0, server.getPriorityLevel(newSchedulable(ugi)));
} finally {
stop(server, null);
}
}
private static Schedulable newSchedulable(UserGroupInformation ugi) {
return new Schedulable(){
@Override
public UserGroupInformation getUserGroupInformation() {
return ugi;
}
@Override
public int getPriorityLevel() {
return 0; // doesn't matter.
}
};
}
private Server setupDecayRpcSchedulerandTestServer(String ns) private Server setupDecayRpcSchedulerandTestServer(String ns)
throws Exception { throws Exception {
final int queueSizePerHandler = 3; final int queueSizePerHandler = 3;

View File

@ -62,6 +62,8 @@ public class TestRpcBase {
protected final static String SERVER_PRINCIPAL_KEY = protected final static String SERVER_PRINCIPAL_KEY =
"test.ipc.server.principal"; "test.ipc.server.principal";
protected final static String CLIENT_PRINCIPAL_KEY =
"test.ipc.client.principal";
protected final static String ADDRESS = "0.0.0.0"; protected final static String ADDRESS = "0.0.0.0";
protected final static int PORT = 0; protected final static int PORT = 0;
protected static InetSocketAddress addr; protected static InetSocketAddress addr;
@ -271,7 +273,8 @@ public class TestRpcBase {
} }
} }
@KerberosInfo(serverPrincipal = SERVER_PRINCIPAL_KEY) @KerberosInfo(serverPrincipal = SERVER_PRINCIPAL_KEY,
clientPrincipal = CLIENT_PRINCIPAL_KEY)
@TokenInfo(TestTokenSelector.class) @TokenInfo(TestTokenSelector.class)
@ProtocolInfo(protocolName = "org.apache.hadoop.ipc.TestRpcBase$TestRpcService", @ProtocolInfo(protocolName = "org.apache.hadoop.ipc.TestRpcBase$TestRpcService",
protocolVersion = 1) protocolVersion = 1)