NIFI-10375: If a class is not allowed in the AllowListClassLoader by classname, check the class's module and allow through anything in the java. or jdk. modules.

NIFI-10375: Addressed review feedback: removed loading of classes from JMODs in the StatelessBootstrap because it only was relevant if using JDK. Instead, just inspect the module as we do when using the JRE. Also addressed issue of allow NoClassDefFoundError fly when we should use ClassNotFoundException

This closes #6317.

Signed-off-by: Peter Turcsanyi <turcsanyi@apache.org>
This commit is contained in:
Mark Payne 2022-08-19 15:45:33 -04:00 committed by Peter Turcsanyi
parent 5303dd13aa
commit 6b424c3fd3
No known key found for this signature in database
GPG Key ID: 55A813F1C3E553DC
3 changed files with 155 additions and 32 deletions

View File

@ -17,7 +17,13 @@
package org.apache.nifi.stateless.bootstrap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
/**
@ -37,36 +43,94 @@ import java.util.Set;
* </p>
*/
public class AllowListClassLoader extends ClassLoader {
private final Set<String> allowed;
private static final Logger logger = LoggerFactory.getLogger(AllowListClassLoader.class);
private final Set<String> allowedClassNames;
private final List<String> allowedModulePrefixes = Arrays.asList("java.", "jdk.");
public AllowListClassLoader(final ClassLoader parent, final Set<String> allowed) {
super(parent);
this.allowed = allowed;
this.allowedClassNames = allowed;
}
/**
* @return the set of all Class names that will not be blocked from loading by the parent
*/
public Set<String> getClassesAllowed() {
return Collections.unmodifiableSet(allowed);
return Collections.unmodifiableSet(allowedClassNames);
}
@Override
protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
if (allowed.contains(name)) {
if (allowedClassNames.contains(name)) {
return super.loadClass(name, resolve);
}
try {
final Class<?> found = super.loadClass(name, false);
final boolean allowed = isClassAllowed(name, found);
if (allowed) {
if (resolve) {
super.resolveClass(found);
}
return found;
}
} catch (final NoClassDefFoundError ncdfe) {
// Allow the code to 'fall through' to the ClassNotFoundException below.
}
throw new ClassNotFoundException(name + " was blocked by AllowListClassLoader");
}
@Override
protected Class<?> findClass(final String name) throws ClassNotFoundException {
if (allowed.contains(name)) {
return super.findClass(name);
final Class<?> found = super.findClass(name);
if (isClassAllowed(name, found)) {
return found;
}
throw new ClassNotFoundException(name + " was blocked by AllowListClassLoader");
}
private boolean isClassAllowed(final String name, final Class<?> clazz) {
// If the name of the class is in the allowed class names, allow it.
if (allowedClassNames.contains(name)) {
return true;
}
// If the class has a module whose name is allowed, allow it.
// The module is obtained by calling Class.getModule(). However, that method is only available in Java 9.
// Since this codebase must be Java 8 compatible we can't make that method call. So we use Reflection to determine
// if the getModule method exists (which it will for Java 9+ but not Java 1.8), and if so get the name of the module.
try {
final Method getModule = Class.class.getMethod("getModule");
final Object module = getModule.invoke(clazz);
if (module == null) {
return false;
}
final Method getName = module.getClass().getMethod("getName");
final String moduleName = (String) getName.invoke(module);
if (isModuleAllowed(moduleName)) {
logger.debug("Allowing Class {} because its module is {}", name, moduleName);
return true;
}
return false;
} catch (final Exception e) {
logger.debug("Failed to determine if class {} is part of the implicitly allowed modules", name, e);
return false;
}
}
private boolean isModuleAllowed(final String moduleName) {
for (final String prefix : allowedModulePrefixes) {
if (moduleName.startsWith(prefix)) {
return true;
}
}
return false;
}
}

View File

@ -21,9 +21,9 @@ import org.apache.nifi.bundle.Bundle;
import org.apache.nifi.bundle.BundleCoordinate;
import org.apache.nifi.nar.NarClassLoader;
import org.apache.nifi.nar.NarClassLoaders;
import org.apache.nifi.nar.NarUnpackMode;
import org.apache.nifi.nar.NarUnpacker;
import org.apache.nifi.nar.SystemBundle;
import org.apache.nifi.nar.NarUnpackMode;
import org.apache.nifi.stateless.config.ParameterOverride;
import org.apache.nifi.stateless.config.StatelessConfigurationException;
import org.apache.nifi.stateless.engine.NarUnpackLock;
@ -37,8 +37,11 @@ import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
@ -50,7 +53,6 @@ import java.util.function.Predicate;
import java.util.jar.JarFile;
import java.util.regex.Pattern;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
public class StatelessBootstrap {
private static final Logger logger = LoggerFactory.getLogger(StatelessBootstrap.class);
@ -147,7 +149,7 @@ public class StatelessBootstrap {
* @param parent the parent class loader that the given BlockListClassLoader should delegate to for classes that it does not block
* @return an AllowListClassLoader that allows only the appropriate classes to be loaded from the given parent
*/
private static AllowListClassLoader createExtensionRootClassLoader(final File narDirectory, final ClassLoader parent) throws IOException {
protected static AllowListClassLoader createExtensionRootClassLoader(final File narDirectory, final ClassLoader parent) throws IOException {
final File[] narDirectoryFiles = narDirectory.listFiles();
if (narDirectoryFiles == null) {
throw new IOException("Could not get a listing of the NAR directory");
@ -183,6 +185,12 @@ public class StatelessBootstrap {
logger.debug("The following class/JAR files will be explicitly allowed to be loaded by Stateless Extensions ClassLoaders from parent {}: {}", parent, filesAllowed);
logger.debug("The following JAR/JMOD files from ${JAVA_HOME} will be explicitly allowed to be loaded by Stateless Extensions ClassLoaders from parent {}: {}", parent, javaHomeFilenames);
logger.debug("The final list of classes allowed to be loaded by Stateless Extension ClassLoaders from parent {}: {}", parent, classesAllowed);
if (parent instanceof URLClassLoader) {
final URL[] parentUrls = ((URLClassLoader) parent).getURLs();
logger.debug("Parent ClassLoader has the following URLs loaded: {}", Arrays.asList(parentUrls));
} else {
logger.debug("Parent ClassLoader is not a URLClassLoader: {} / {}", parent, parent.getClass());
}
final AllowListClassLoader allowListClassLoader = new AllowListClassLoader(parent, classesAllowed);
return allowListClassLoader;
@ -200,12 +208,18 @@ public class StatelessBootstrap {
logger.warn("System property for java.home is {} but that directory does not exist so will not allow any classes explicitly from java.home in AllowListClassLoader", javaHomeValue);
return Collections.emptySet();
}
logger.debug("Java Home Directory is {}", javaHome.getAbsolutePath());
final File[] javaHomeFiles = javaHome.listFiles();
if (javaHomeFiles == null) {
logger.warn("System property for java.home is {} but that directory is not readable so will not allow any classes explicitly from java.home in AllowListClassLoader", javaHomeValue);
return Collections.emptySet();
}
if (logger.isDebugEnabled()) {
logger.debug("Found the following files in Java Home: {}", Arrays.asList(javaHomeFiles));
logger.debug("Full listing of Java Home:");
logFullJavaHomeListing(javaHomeFiles);
}
final Set<File> loadableFiles = new HashSet<>();
for (final File file : javaHomeFiles) {
@ -215,10 +229,32 @@ public class StatelessBootstrap {
return loadableFiles;
}
private static void logFullJavaHomeListing(final File[] files) {
if (files == null) {
return;
}
for (final File file : files) {
if (file.isDirectory()) {
logger.debug(file.getAbsolutePath() + "/");
final File[] children = file.listFiles();
if (children == null) {
logger.debug("Failed to perform listing of directory {}", file);
continue;
}
logFullJavaHomeListing(children);
} else {
logger.debug(file.getAbsolutePath());
}
}
}
private static void findLoadableFiles(final File file, final Set<File> loadable) {
if (file.isDirectory()) {
final File[] children = file.listFiles();
if (children == null) {
logger.debug("Unable to obtain listing of files for directory {}", file.getAbsolutePath());
return;
}
@ -230,7 +266,7 @@ public class StatelessBootstrap {
}
final String filename = file.getName();
if (filename.endsWith(".jar") || filename.endsWith(".jmod")) {
if (filename.endsWith(".jar")) {
loadable.add(file);
}
}
@ -239,8 +275,6 @@ public class StatelessBootstrap {
final String filename = file.getName();
if (filename.endsWith(".jar")) {
findClassNamesInJar(file, classNames);
} else if (filename.endsWith(".jmod")) {
findClassesInJmod(file, classNames);
}
}
@ -294,26 +328,6 @@ public class StatelessBootstrap {
}
}
private static void findClassesInJmod(final File file, final Set<String> classNames) throws IOException {
if (!file.getName().endsWith(".jmod") || !file.isFile() || !file.exists()) {
return;
}
try (final ZipFile zipFile = new ZipFile(file)) {
final Enumeration<? extends ZipEntry> enumeration = zipFile.entries();
while (enumeration.hasMoreElements()) {
final ZipEntry zipEntry = enumeration.nextElement();
final String entryName = zipEntry.getName();
if (entryName.startsWith("classes/") && entryName.endsWith(".class")) {
final int lastIndex = entryName.lastIndexOf(".class");
final String className = entryName.substring(8, lastIndex).replace("/", ".");
classNames.add(className);
}
}
}
}
private static File locateStatelessNarWorkingDirectory(final File workingDirectory) throws IOException {
final File[] files = workingDirectory.listFiles();
if (files == null) {

View File

@ -17,13 +17,18 @@
package org.apache.nifi.stateless.bootstrap;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.HashSet;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestStatelessBootstrap {
@ -42,4 +47,44 @@ public class TestStatelessBootstrap {
assertTrue(fileNames.contains("FakeBootstrap.class"));
}
@Test
public void testClassloaderLoadsJavaLangObject() throws IOException, ClassNotFoundException {
final File narDirectory = new File("target");
try (final URLClassLoader systemClassLoader = new URLClassLoader(new URL[0])) {
final AllowListClassLoader allowListClassLoader = StatelessBootstrap.createExtensionRootClassLoader(narDirectory, systemClassLoader);
final Class<?> objectClass = allowListClassLoader.loadClass("java.lang.Object", true);
assertNotNull(objectClass);
}
}
@Test
public void testClassNotAllowed() throws IOException, ClassNotFoundException {
// Specify a class that should be loaded by the system class loader
final File classFile = new File("target/classes");
final URL classUrl = classFile.toURI().toURL();
final String classToLoad = "org.apache.nifi.stateless.bootstrap.RunStatelessFlow";
// A directory for NARs, jars, etc. that are allowed by the AllowListClassLoader
final File narDirectory = new File("target");
// Create a URLClassLoader to use for the System ClassLoader. This will load the classes from the target/ directory.
// Then create an AllowListClassLoader that will not allow these classes through.
// Ensure that the classes are not allowed through, but that classes in the java.lang still are.
try (final URLClassLoader systemClassLoader = new URLClassLoader(new URL[] {classUrl})) {
final AllowListClassLoader allowListClassLoader = StatelessBootstrap.createExtensionRootClassLoader(narDirectory, systemClassLoader);
final Class<?> classFromSystemLoader = systemClassLoader.loadClass(classToLoad);
assertNotNull(classFromSystemLoader);
allowListClassLoader.loadClass("java.util.logging.Logger", true);
Assertions.assertThrows(Exception.class, () -> {
allowListClassLoader.loadClass(classToLoad);
});
final Class<?> objectClass = allowListClassLoader.loadClass("java.lang.Object", true);
assertNotNull(objectClass);
}
}
}