HBASE-19920 Lazy init for ProtobufUtil classloader

This commit is contained in:
Mike Drob 2018-02-08 16:12:46 -06:00
parent 70d3413ee2
commit 138f82c8c5
5 changed files with 146 additions and 35 deletions

View File

@ -19,6 +19,8 @@ package org.apache.hadoop.hbase.ipc;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedAction;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.DoNotRetryIOException; import org.apache.hadoop.hbase.DoNotRetryIOException;
@ -36,19 +38,24 @@ import org.apache.hadoop.ipc.RemoteException;
*/ */
@SuppressWarnings("serial") @SuppressWarnings("serial")
@InterfaceAudience.Public @InterfaceAudience.Public
@edu.umd.cs.findbugs.annotations.SuppressWarnings(
value = "DP_CREATE_CLASSLOADER_INSIDE_DO_PRIVILEGED", justification = "None. Address sometime.")
public class RemoteWithExtrasException extends RemoteException { public class RemoteWithExtrasException extends RemoteException {
private final String hostname; private final String hostname;
private final int port; private final int port;
private final boolean doNotRetry; private final boolean doNotRetry;
private final static ClassLoader CLASS_LOADER; /**
* Dynamic class loader to load filter/comparators
*/
private final static class ClassLoaderHolder {
private final static ClassLoader CLASS_LOADER;
static { static {
ClassLoader parent = RemoteWithExtrasException.class.getClassLoader(); ClassLoader parent = RemoteWithExtrasException.class.getClassLoader();
Configuration conf = HBaseConfiguration.create(); Configuration conf = HBaseConfiguration.create();
CLASS_LOADER = new DynamicClassLoader(conf, parent); CLASS_LOADER = AccessController.doPrivileged((PrivilegedAction<ClassLoader>)
() -> new DynamicClassLoader(conf, parent)
);
}
} }
public RemoteWithExtrasException(String className, String msg, final boolean doNotRetry) { public RemoteWithExtrasException(String className, String msg, final boolean doNotRetry) {
@ -69,7 +76,7 @@ public class RemoteWithExtrasException extends RemoteException {
try { try {
// try to load a exception class from where the HBase classes are loaded or from Dynamic // try to load a exception class from where the HBase classes are loaded or from Dynamic
// classloader. // classloader.
realClass = Class.forName(getClassName(), false, CLASS_LOADER); realClass = Class.forName(getClassName(), false, ClassLoaderHolder.CLASS_LOADER);
} catch (ClassNotFoundException cnfe) { } catch (ClassNotFoundException cnfe) {
try { try {
// cause could be a hadoop exception, try to load from hadoop classpath // cause could be a hadoop exception, try to load from hadoop classpath

View File

@ -31,6 +31,8 @@ import com.google.protobuf.TextFormat;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -113,8 +115,6 @@ import org.apache.yetus.audience.InterfaceAudience;
* @see ProtobufUtil * @see ProtobufUtil
*/ */
// TODO: Generate this class from the shaded version. // TODO: Generate this class from the shaded version.
@edu.umd.cs.findbugs.annotations.SuppressWarnings(
value="DP_CREATE_CLASSLOADER_INSIDE_DO_PRIVILEGED", justification="None. Address sometime.")
@InterfaceAudience.Private // TODO: some clients (Hive, etc) use this class. @InterfaceAudience.Private // TODO: some clients (Hive, etc) use this class.
public final class ProtobufUtil { public final class ProtobufUtil {
private ProtobufUtil() { private ProtobufUtil() {
@ -168,14 +168,18 @@ public final class ProtobufUtil {
} }
/** /**
* Dynamic class loader to load filter/comparators * Dynamic class loader to load filter/comparators
*/ */
private final static ClassLoader CLASS_LOADER; private final static class ClassLoaderHolder {
private final static ClassLoader CLASS_LOADER;
static { static {
ClassLoader parent = ProtobufUtil.class.getClassLoader(); ClassLoader parent = ProtobufUtil.class.getClassLoader();
Configuration conf = HBaseConfiguration.create(); Configuration conf = HBaseConfiguration.create();
CLASS_LOADER = new DynamicClassLoader(conf, parent); CLASS_LOADER = AccessController.doPrivileged((PrivilegedAction<ClassLoader>)
() -> new DynamicClassLoader(conf, parent)
);
}
} }
/** /**
@ -1430,8 +1434,7 @@ public final class ProtobufUtil {
String funcName = "parseFrom"; String funcName = "parseFrom";
byte [] value = proto.getSerializedComparator().toByteArray(); byte [] value = proto.getSerializedComparator().toByteArray();
try { try {
Class<? extends ByteArrayComparable> c = Class<?> c = Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
(Class<? extends ByteArrayComparable>)Class.forName(type, true, CLASS_LOADER);
Method parseFrom = c.getMethod(funcName, byte[].class); Method parseFrom = c.getMethod(funcName, byte[].class);
if (parseFrom == null) { if (parseFrom == null) {
throw new IOException("Unable to locate function: " + funcName + " in type: " + type); throw new IOException("Unable to locate function: " + funcName + " in type: " + type);
@ -1454,8 +1457,7 @@ public final class ProtobufUtil {
final byte [] value = proto.getSerializedFilter().toByteArray(); final byte [] value = proto.getSerializedFilter().toByteArray();
String funcName = "parseFrom"; String funcName = "parseFrom";
try { try {
Class<? extends Filter> c = Class<?> c = Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
(Class<? extends Filter>)Class.forName(type, true, CLASS_LOADER);
Method parseFrom = c.getMethod(funcName, byte[].class); Method parseFrom = c.getMethod(funcName, byte[].class);
if (parseFrom == null) { if (parseFrom == null) {
throw new IOException("Unable to locate function: " + funcName + " in type: " + type); throw new IOException("Unable to locate function: " + funcName + " in type: " + type);
@ -1541,7 +1543,7 @@ public final class ProtobufUtil {
String type = parameter.getName(); String type = parameter.getName();
try { try {
Class<? extends Throwable> c = Class<? extends Throwable> c =
(Class<? extends Throwable>)Class.forName(type, true, CLASS_LOADER); (Class<? extends Throwable>)Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
Constructor<? extends Throwable> cn = null; Constructor<? extends Throwable> cn = null;
try { try {
cn = c.getDeclaredConstructor(String.class); cn = c.getDeclaredConstructor(String.class);

View File

@ -23,6 +23,8 @@ import java.io.InputStream;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -105,6 +107,7 @@ import org.apache.hadoop.hbase.util.VersionInfo;
import org.apache.hadoop.ipc.RemoteException; import org.apache.hadoop.ipc.RemoteException;
import org.apache.yetus.audience.InterfaceAudience; import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hbase.thirdparty.com.google.common.annotations.VisibleForTesting;
import org.apache.hbase.thirdparty.com.google.common.io.ByteStreams; import org.apache.hbase.thirdparty.com.google.common.io.ByteStreams;
import org.apache.hbase.thirdparty.com.google.gson.JsonArray; import org.apache.hbase.thirdparty.com.google.gson.JsonArray;
import org.apache.hbase.thirdparty.com.google.gson.JsonElement; import org.apache.hbase.thirdparty.com.google.gson.JsonElement;
@ -191,8 +194,6 @@ import org.apache.hadoop.hbase.shaded.protobuf.generated.ZooKeeperProtos;
* @see ProtobufUtil * @see ProtobufUtil
*/ */
// TODO: Generate the non-shaded protobufutil from this one. // TODO: Generate the non-shaded protobufutil from this one.
@edu.umd.cs.findbugs.annotations.SuppressWarnings(
value="DP_CREATE_CLASSLOADER_INSIDE_DO_PRIVILEGED", justification="None. Address sometime.")
@InterfaceAudience.Private // TODO: some clients (Hive, etc) use this class @InterfaceAudience.Private // TODO: some clients (Hive, etc) use this class
public final class ProtobufUtil { public final class ProtobufUtil {
private ProtobufUtil() { private ProtobufUtil() {
@ -244,15 +245,27 @@ public final class ProtobufUtil {
EMPTY_RESULT_PB_STALE = builder.build(); EMPTY_RESULT_PB_STALE = builder.build();
} }
private static volatile boolean classLoaderLoaded = false;
/** /**
* Dynamic class loader to load filter/comparators * Dynamic class loader to load filter/comparators
*/ */
private final static ClassLoader CLASS_LOADER; private final static class ClassLoaderHolder {
private final static ClassLoader CLASS_LOADER;
static { static {
ClassLoader parent = ProtobufUtil.class.getClassLoader(); ClassLoader parent = ProtobufUtil.class.getClassLoader();
Configuration conf = HBaseConfiguration.create(); Configuration conf = HBaseConfiguration.create();
CLASS_LOADER = new DynamicClassLoader(conf, parent); CLASS_LOADER = AccessController.doPrivileged((PrivilegedAction<ClassLoader>)
() -> new DynamicClassLoader(conf, parent)
);
classLoaderLoaded = true;
}
}
@VisibleForTesting
public static boolean isClassLoaderLoaded() {
return classLoaderLoaded;
} }
/** /**
@ -1586,8 +1599,7 @@ public final class ProtobufUtil {
String funcName = "parseFrom"; String funcName = "parseFrom";
byte [] value = proto.getSerializedComparator().toByteArray(); byte [] value = proto.getSerializedComparator().toByteArray();
try { try {
Class<? extends ByteArrayComparable> c = Class<?> c = Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
(Class<? extends ByteArrayComparable>)Class.forName(type, true, CLASS_LOADER);
Method parseFrom = c.getMethod(funcName, byte[].class); Method parseFrom = c.getMethod(funcName, byte[].class);
if (parseFrom == null) { if (parseFrom == null) {
throw new IOException("Unable to locate function: " + funcName + " in type: " + type); throw new IOException("Unable to locate function: " + funcName + " in type: " + type);
@ -1610,8 +1622,7 @@ public final class ProtobufUtil {
final byte [] value = proto.getSerializedFilter().toByteArray(); final byte [] value = proto.getSerializedFilter().toByteArray();
String funcName = "parseFrom"; String funcName = "parseFrom";
try { try {
Class<? extends Filter> c = Class<?> c = Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
(Class<? extends Filter>)Class.forName(type, true, CLASS_LOADER);
Method parseFrom = c.getMethod(funcName, byte[].class); Method parseFrom = c.getMethod(funcName, byte[].class);
if (parseFrom == null) { if (parseFrom == null) {
throw new IOException("Unable to locate function: " + funcName + " in type: " + type); throw new IOException("Unable to locate function: " + funcName + " in type: " + type);
@ -1697,7 +1708,7 @@ public final class ProtobufUtil {
String type = parameter.getName(); String type = parameter.getName();
try { try {
Class<? extends Throwable> c = Class<? extends Throwable> c =
(Class<? extends Throwable>)Class.forName(type, true, CLASS_LOADER); (Class<? extends Throwable>)Class.forName(type, true, ClassLoaderHolder.CLASS_LOADER);
Constructor<? extends Throwable> cn = null; Constructor<? extends Throwable> cn = null;
try { try {
cn = c.getDeclaredConstructor(String.class); cn = c.getDeclaredConstructor(String.class);

View File

@ -26,7 +26,6 @@ import com.google.protobuf.ByteString;
import com.google.protobuf.ServiceException; import com.google.protobuf.ServiceException;
import org.apache.hadoop.hbase.zookeeper.ZKWatcher; import org.apache.hadoop.hbase.zookeeper.ZKWatcher;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HConstants; import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.TableName; import org.apache.hadoop.hbase.TableName;
@ -41,6 +40,7 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.Token;
import org.apache.yetus.audience.InterfaceAudience;
import org.apache.zookeeper.KeeperException; import org.apache.zookeeper.KeeperException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -53,6 +53,15 @@ public class TokenUtil {
// This class is referenced indirectly by User out in common; instances are created by reflection // This class is referenced indirectly by User out in common; instances are created by reflection
private static final Logger LOG = LoggerFactory.getLogger(TokenUtil.class); private static final Logger LOG = LoggerFactory.getLogger(TokenUtil.class);
// Set in TestTokenUtil via reflection
private static ServiceException injectedException;
private static void injectFault() throws ServiceException {
if (injectedException != null) {
throw injectedException;
}
}
/** /**
* Obtain and return an authentication token for the current user. * Obtain and return an authentication token for the current user.
* @param conn The HBase cluster connection * @param conn The HBase cluster connection
@ -63,6 +72,8 @@ public class TokenUtil {
Connection conn) throws IOException { Connection conn) throws IOException {
Table meta = null; Table meta = null;
try { try {
injectFault();
meta = conn.getTable(TableName.META_TABLE_NAME); meta = conn.getTable(TableName.META_TABLE_NAME);
CoprocessorRpcChannel rpcChannel = meta.coprocessorService(HConstants.EMPTY_START_ROW); CoprocessorRpcChannel rpcChannel = meta.coprocessorService(HConstants.EMPTY_START_ROW);
AuthenticationProtos.AuthenticationService.BlockingInterface service = AuthenticationProtos.AuthenticationService.BlockingInterface service =

View File

@ -0,0 +1,80 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.security.token;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.net.URLClassLoader;
import org.apache.hadoop.hbase.HBaseClassTestRule;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.testclassification.SmallTests;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.apache.hadoop.hbase.shaded.protobuf.ProtobufUtil;
@Category(SmallTests.class)
public class TestTokenUtil {
@ClassRule
public static final HBaseClassTestRule CLASS_RULE =
HBaseClassTestRule.forClass(TestTokenUtil.class);
@Test
public void testObtainToken() throws Exception {
URL urlPU = ProtobufUtil.class.getProtectionDomain().getCodeSource().getLocation();
URL urlTU = TokenUtil.class.getProtectionDomain().getCodeSource().getLocation();
ClassLoader cl = new URLClassLoader(new URL[] { urlPU, urlTU }, getClass().getClassLoader());
Throwable injected = new com.google.protobuf.ServiceException("injected");
Class<?> tokenUtil = cl.loadClass(TokenUtil.class.getCanonicalName());
Field shouldInjectFault = tokenUtil.getDeclaredField("injectedException");
shouldInjectFault.setAccessible(true);
shouldInjectFault.set(null, injected);
try {
tokenUtil.getMethod("obtainToken", Connection.class)
.invoke(null, new Object[] { null });
fail("Should have injected exception.");
} catch (InvocationTargetException e) {
Throwable t = e;
boolean serviceExceptionFound = false;
while ((t = t.getCause()) != null) {
if (t == injected) { // reference equality
serviceExceptionFound = true;
break;
}
}
if (!serviceExceptionFound) {
throw e; // wrong exception, fail the test
}
}
Boolean loaded = (Boolean) cl.loadClass(ProtobufUtil.class.getCanonicalName())
.getDeclaredMethod("isClassLoaderLoaded")
.invoke(null);
assertFalse("Should not have loaded DynamicClassLoader", loaded);
}
}