inter-Extension dependency support (#16973)

* update docs for kafka lookup extension to specify correct extension ordering

* fix first line

* test with extension dependencies

* save work on dependency management

* working dependency graph

* working pull

* fix style

* fix style

* remove name

* load extension dependencies recursively

* generate depenencies on classloader creation

* add check for circular dependencies

* fix style

* revert style changes

* remove mutable class loader

* clean up class heirarchy

* extensions loader test working

* add unit tests

* pr comments

* fix unit tests
This commit is contained in:
George Shiqi Wu 2024-09-24 14:17:33 -04:00 committed by GitHub
parent 5c862f6ed9
commit d1bfabbf4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 521 additions and 63 deletions

View File

@ -44,6 +44,7 @@
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-lookups-cached-global</artifactId>
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>

View File

@ -0,0 +1,3 @@
{
"dependsOnDruidExtensions": ["druid-lookups-cached-global"]
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.guice;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.List;
public class DruidExtensionDependencies
{
@JsonProperty("dependsOnDruidExtensions")
private List<String> dependsOnDruidExtensions;
public DruidExtensionDependencies()
{
this.dependsOnDruidExtensions = new ArrayList<>();
}
public DruidExtensionDependencies(@Nonnull final List<String> dependsOnDruidExtensions)
{
this.dependsOnDruidExtensions = dependsOnDruidExtensions;
}
public List<String> getDependsOnDruidExtensions()
{
return dependsOnDruidExtensions;
}
}

View File

@ -24,7 +24,6 @@ import com.google.common.collect.Iterators;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
@ -32,13 +31,13 @@ import java.util.List;
/**
* The ClassLoader that gets used when druid.extensions.useExtensionClassloaderFirst = true.
*/
public class ExtensionFirstClassLoader extends URLClassLoader
public class ExtensionFirstClassLoader extends StandardURLClassLoader
{
private final ClassLoader druidLoader;
public ExtensionFirstClassLoader(final URL[] urls, final ClassLoader druidLoader)
public ExtensionFirstClassLoader(final URL[] urls, final ClassLoader druidLoader, final List<ClassLoader> extensionDependencyClassLoaders)
{
super(urls, null);
super(urls, null, extensionDependencyClassLoaders);
this.druidLoader = Preconditions.checkNotNull(druidLoader, "druidLoader");
}
@ -60,8 +59,13 @@ public class ExtensionFirstClassLoader extends URLClassLoader
clazz = findClass(name);
}
catch (ClassNotFoundException e) {
// Try the Druid classloader. Will throw ClassNotFoundException if the class can't be loaded.
return druidLoader.loadClass(name);
try {
clazz = loadClassFromExtensionDependencies(name);
}
catch (ClassNotFoundException e2) {
// Try the Druid classloader. Will throw ClassNotFoundException if the class can't be loaded.
clazz = druidLoader.loadClass(name);
}
}
}
@ -76,13 +80,18 @@ public class ExtensionFirstClassLoader extends URLClassLoader
@Override
public URL getResource(final String name)
{
final URL resourceFromExtension = super.getResource(name);
URL resourceFromExtension = super.getResource(name);
if (resourceFromExtension != null) {
return resourceFromExtension;
} else {
return druidLoader.getResource(name);
}
resourceFromExtension = getResourceFromExtensionsDependencies(name);
if (resourceFromExtension != null) {
return resourceFromExtension;
}
return druidLoader.getResource(name);
}
@Override
@ -90,6 +99,7 @@ public class ExtensionFirstClassLoader extends URLClassLoader
{
final List<URL> urls = new ArrayList<>();
Iterators.addAll(urls, Iterators.forEnumeration(super.getResources(name)));
addExtensionResources(name, urls);
Iterators.addAll(urls, Iterators.forEnumeration(druidLoader.getResources(name)));
return Iterators.asEnumeration(urls.iterator());
}

View File

@ -19,13 +19,19 @@
package org.apache.druid.guice;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Injector;
import org.apache.commons.io.FileUtils;
import org.apache.druid.initialization.DruidModule;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import javax.inject.Inject;
@ -34,18 +40,21 @@ import java.io.FilenameFilter;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
/**
@ -59,19 +68,28 @@ import java.util.stream.Collectors;
public class ExtensionsLoader
{
private static final Logger log = new Logger(ExtensionsLoader.class);
public static final String DRUID_EXTENSION_DEPENDENCIES_JSON = "druid-extension-dependencies.json";
private final ExtensionsConfig extensionsConfig;
private final ConcurrentHashMap<Pair<File, Boolean>, URLClassLoader> loaders = new ConcurrentHashMap<>();
private final ObjectMapper objectMapper;
@GuardedBy("this")
private final HashMap<Pair<File, Boolean>, StandardURLClassLoader> loaders = new HashMap<>();
/**
* Map of loaded extensions, keyed by class (or interface).
*/
private final ConcurrentHashMap<Class<?>, Collection<?>> extensions = new ConcurrentHashMap<>();
@GuardedBy("this")
private final HashMap<Class<?>, Collection<?>> extensions = new HashMap<>();
@GuardedBy("this")
@MonotonicNonNull
private File[] extensionFilesToLoad;
@Inject
public ExtensionsLoader(ExtensionsConfig config)
public ExtensionsLoader(ExtensionsConfig config, ObjectMapper objectMapper)
{
this.extensionsConfig = config;
this.objectMapper = objectMapper;
}
public static ExtensionsLoader instance(Injector injector)
@ -92,12 +110,14 @@ public class ExtensionsLoader
*/
public <T> Collection<T> getLoadedImplementations(Class<T> clazz)
{
@SuppressWarnings("unchecked")
Collection<T> retVal = (Collection<T>) extensions.get(clazz);
if (retVal == null) {
return Collections.emptySet();
synchronized (this) {
@SuppressWarnings("unchecked")
Collection<T> retVal = (Collection<T>) extensions.get(clazz);
if (retVal == null) {
return Collections.emptySet();
}
return retVal;
}
return retVal;
}
/**
@ -109,9 +129,11 @@ public class ExtensionsLoader
}
@VisibleForTesting
public Map<Pair<File, Boolean>, URLClassLoader> getLoadersMap()
public Map<Pair<File, Boolean>, StandardURLClassLoader> getLoadersMap()
{
return loaders;
synchronized (this) {
return loaders;
}
}
/**
@ -135,12 +157,14 @@ public class ExtensionsLoader
// In practice, it appears the only place this matters is with DruidModule:
// initialization gets the list of extensions, and two REST API calls later
// ask for the same list.
Collection<?> modules = extensions.computeIfAbsent(
serviceClass,
serviceC -> new ServiceLoadingFromExtensions<>(serviceC).implsToLoad
);
//noinspection unchecked
return (Collection<T>) modules;
synchronized (this) {
Collection<?> modules = extensions.computeIfAbsent(
serviceClass,
serviceC -> new ServiceLoadingFromExtensions<>(serviceC).implsToLoad
);
//noinspection unchecked
return (Collection<T>) modules;
}
}
public Collection<DruidModule> getModules()
@ -159,7 +183,7 @@ public class ExtensionsLoader
*
* @return an array of druid extension files that will be loaded by druid process
*/
public File[] getExtensionFilesToLoad()
public void initializeExtensionFilesToLoad()
{
final File rootExtensionsDir = new File(extensionsConfig.getDirectory());
if (rootExtensionsDir.exists() && !rootExtensionsDir.isDirectory()) {
@ -187,25 +211,98 @@ public class ExtensionsLoader
extensionsToLoad[i++] = extensionDir;
}
}
return extensionsToLoad == null ? new File[]{} : extensionsToLoad;
synchronized (this) {
extensionFilesToLoad = extensionsToLoad == null ? new File[]{} : extensionsToLoad;
}
}
public File[] getExtensionFilesToLoad()
{
synchronized (this) {
if (extensionFilesToLoad == null) {
initializeExtensionFilesToLoad();
}
return extensionFilesToLoad;
}
}
/**
* @param extension The File instance of the extension we want to load
*
* @return a URLClassLoader that loads all the jars on which the extension is dependent
* @return a StandardURLClassLoader that loads all the jars on which the extension is dependent
*/
public URLClassLoader getClassLoaderForExtension(File extension, boolean useExtensionClassloaderFirst)
public StandardURLClassLoader getClassLoaderForExtension(File extension, boolean useExtensionClassloaderFirst)
{
return loaders.computeIfAbsent(
Pair.of(extension, useExtensionClassloaderFirst),
k -> makeClassLoaderForExtension(k.lhs, k.rhs)
);
return getClassLoaderForExtension(extension, useExtensionClassloaderFirst, new ArrayList<>());
}
private static URLClassLoader makeClassLoaderForExtension(
/**
* @param extension The File instance of the extension we want to load
* @param useExtensionClassloaderFirst Whether to return a StandardURLClassLoader that checks extension classloaders first for classes
* @param extensionDependencyStack If the extension is requested as a dependency of another extension, a list containing the
* dependency stack of the dependent extension (for checking circular dependencies). Otherwise
* this is a empty list.
* @return a StandardURLClassLoader that loads all the jars on which the extension is dependent
*/
public StandardURLClassLoader getClassLoaderForExtension(File extension, boolean useExtensionClassloaderFirst, List<String> extensionDependencyStack)
{
Pair<File, Boolean> classLoaderKey = Pair.of(extension, useExtensionClassloaderFirst);
synchronized (this) {
StandardURLClassLoader classLoader = loaders.get(classLoaderKey);
if (classLoader == null) {
classLoader = makeClassLoaderWithDruidExtensionDependencies(extension, useExtensionClassloaderFirst, extensionDependencyStack);
loaders.put(classLoaderKey, classLoader);
}
return classLoader;
}
}
private StandardURLClassLoader makeClassLoaderWithDruidExtensionDependencies(File extension, boolean useExtensionClassloaderFirst, List<String> extensionDependencyStack)
{
Optional<DruidExtensionDependencies> druidExtensionDependenciesOptional = getDruidExtensionDependencies(extension);
List<String> druidExtensionDependenciesList = druidExtensionDependenciesOptional.isPresent()
? druidExtensionDependenciesOptional.get().getDependsOnDruidExtensions()
: ImmutableList.of();
List<ClassLoader> extensionDependencyClassLoaders = new ArrayList<>();
for (String druidExtensionDependencyName : druidExtensionDependenciesList) {
Optional<File> extensionDependencyFileOptional = Arrays.stream(getExtensionFilesToLoad())
.filter(file -> file.getName().equals(druidExtensionDependencyName))
.findFirst();
if (!extensionDependencyFileOptional.isPresent()) {
throw new RE(
StringUtils.format(
"Extension [%s] depends on [%s] which is not a valid extension or not loaded.",
extension.getName(),
druidExtensionDependencyName
)
);
}
File extensionDependencyFile = extensionDependencyFileOptional.get();
if (extensionDependencyStack.contains(extensionDependencyFile.getName())) {
extensionDependencyStack.add(extensionDependencyFile.getName());
throw new RE(
StringUtils.format(
"Extension [%s] has a circular druid extension dependency. Dependency stack [%s].",
extensionDependencyStack.get(0),
extensionDependencyStack
)
);
}
extensionDependencyStack.add(extensionDependencyFile.getName());
extensionDependencyClassLoaders.add(
getClassLoaderForExtension(extensionDependencyFile, useExtensionClassloaderFirst, extensionDependencyStack)
);
}
return makeClassLoaderForExtension(extension, useExtensionClassloaderFirst, extensionDependencyClassLoaders);
}
private static StandardURLClassLoader makeClassLoaderForExtension(
final File extension,
final boolean useExtensionClassloaderFirst
final boolean useExtensionClassloaderFirst,
final List<ClassLoader> extensionDependencyClassLoaders
)
{
final Collection<File> jars = FileUtils.listFiles(extension, new String[]{"jar"}, false);
@ -224,9 +321,9 @@ public class ExtensionsLoader
}
if (useExtensionClassloaderFirst) {
return new ExtensionFirstClassLoader(urls, ExtensionsLoader.class.getClassLoader());
return new ExtensionFirstClassLoader(urls, ExtensionsLoader.class.getClassLoader(), extensionDependencyClassLoaders);
} else {
return new URLClassLoader(urls, ExtensionsLoader.class.getClassLoader());
return new StandardURLClassLoader(urls, ExtensionsLoader.class.getClassLoader(), extensionDependencyClassLoaders);
}
}
@ -266,6 +363,45 @@ public class ExtensionsLoader
}
}
private Optional<DruidExtensionDependencies> getDruidExtensionDependencies(File extension)
{
final Collection<File> jars = FileUtils.listFiles(extension, new String[]{"jar"}, false);
DruidExtensionDependencies druidExtensionDependencies = null;
String druidExtensionDependenciesJarName = null;
for (File extensionFile : jars) {
try (JarFile jarFile = new JarFile(extensionFile.getPath())) {
Enumeration<JarEntry> entries = jarFile.entries();
while (entries.hasMoreElements()) {
JarEntry entry = entries.nextElement();
String entryName = entry.getName();
if (DRUID_EXTENSION_DEPENDENCIES_JSON.equals(entryName)) {
log.debug("Found extension dependency entry in jar [%s]", extensionFile.getPath());
if (druidExtensionDependenciesJarName != null) {
throw new RE(
StringUtils.format(
"The extension [%s] has multiple jars [%s] [%s] with dependencies in them. Each jar should be in a separate extension directory.",
extension.getName(),
druidExtensionDependenciesJarName,
jarFile.getName()
)
);
}
druidExtensionDependencies = objectMapper.readValue(
jarFile.getInputStream(entry),
DruidExtensionDependencies.class
);
druidExtensionDependenciesJarName = jarFile.getName();
}
}
}
catch (IOException e) {
throw new RE(e, "Failed to get dependencies for extension [%s]", extension);
}
}
return druidExtensionDependencies == null ? Optional.empty() : Optional.of(druidExtensionDependencies);
}
private class ServiceLoadingFromExtensions<T>
{
private final Class<T> serviceClass;
@ -293,17 +429,17 @@ public class ExtensionsLoader
for (File extension : getExtensionFilesToLoad()) {
log.debug("Loading extension [%s] for class [%s]", extension.getName(), serviceClass);
try {
final URLClassLoader loader = getClassLoaderForExtension(
final StandardURLClassLoader loader = getClassLoaderForExtension(
extension,
extensionsConfig.isUseExtensionClassloaderFirst()
);
log.info(
"Loading extension [%s], jars: %s",
"Loading extension [%s], jars: %s. Druid extension dependencies [%s]",
extension.getName(),
Arrays.stream(loader.getURLs())
.map(u -> new File(u.getPath()).getName())
.collect(Collectors.joining(", "))
.collect(Collectors.joining(", ")),
loader.getExtensionDependencyClassLoaders()
);
ServiceLoader.load(serviceClass, loader).forEach(impl -> tryAdd(impl, "local file system"));

View File

@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.guice;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterators;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
/**
* The ClassLoader that gets used when druid.extensions.useExtensionClassloaderFirst = false.
*/
public class StandardURLClassLoader extends URLClassLoader
{
private final List<ClassLoader> extensionDependencyClassLoaders;
public StandardURLClassLoader(final URL[] urls, final ClassLoader druidLoader, final List<ClassLoader> extensionDependencyClassLoaders)
{
super(urls, druidLoader);
this.extensionDependencyClassLoaders = Preconditions.checkNotNull(extensionDependencyClassLoaders, "extensionDependencyClassLoaders");
}
@Override
protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException
{
Class<?> clazz;
try {
clazz = super.loadClass(name, resolve);
}
catch (ClassNotFoundException e) {
clazz = loadClassFromExtensionDependencies(name);
}
if (resolve) {
resolveClass(clazz);
}
return clazz;
}
@Override
public URL getResource(final String name)
{
URL resource = super.getResource(name);
if (resource != null) {
return resource;
}
return getResourceFromExtensionsDependencies(name);
}
@Override
public Enumeration<URL> getResources(final String name) throws IOException
{
final List<URL> urls = new ArrayList<>();
Iterators.addAll(urls, Iterators.forEnumeration(super.getResources(name)));
addExtensionResources(name, urls);
return Iterators.asEnumeration(urls.iterator());
}
protected URL getResourceFromExtensionsDependencies(final String name)
{
URL resourceFromExtension = null;
for (ClassLoader classLoader : extensionDependencyClassLoaders) {
resourceFromExtension = classLoader.getResource(name);
if (resourceFromExtension != null) {
break;
}
}
return resourceFromExtension;
}
protected Class<?> loadClassFromExtensionDependencies(final String name) throws ClassNotFoundException
{
for (ClassLoader classLoader : extensionDependencyClassLoaders) {
try {
return classLoader.loadClass(name);
}
catch (ClassNotFoundException ignored) {
}
}
throw new ClassNotFoundException();
}
protected void addExtensionResources(final String name, List<URL> urls) throws IOException
{
for (ClassLoader classLoader : extensionDependencyClassLoaders) {
Iterators.addAll(urls, Iterators.forEnumeration(classLoader.getResources(name)));
}
}
public List<ClassLoader> getExtensionDependencyClassLoaders()
{
return extensionDependencyClassLoaders;
}
}

View File

@ -21,33 +21,46 @@ package org.apache.druid.guice;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import com.google.inject.Injector;
import org.apache.druid.initialization.DruidModule;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarOutputStream;
public class ExtensionsLoaderTest
{
@Rule
public final TemporaryFolder temporaryFolder = new TemporaryFolder();
private final ObjectMapper objectMapper = new ObjectMapper();
private final Map<String, byte[]> jarFileContents = ImmutableMap.of(
"jar-resource",
"jar-resource-contents".getBytes(Charset.defaultCharset())
);
private Injector startupInjector()
{
return new StartupInjectorBuilder()
@ -76,7 +89,7 @@ public class ExtensionsLoaderTest
Pair<File, Boolean> key = Pair.of(extensionDir, true);
extnLoader.getLoadersMap()
.put(key, new URLClassLoader(new URL[]{}, ExtensionsLoader.class.getClassLoader()));
.put(key, new StandardURLClassLoader(new URL[]{}, ExtensionsLoader.class.getClassLoader(), ImmutableList.of()));
Collection<DruidModule> modules = extnLoader.getFromExtensions(DruidModule.class);
@ -90,16 +103,18 @@ public class ExtensionsLoaderTest
@Test
public void test06GetClassLoaderForExtension() throws IOException
{
final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig());
final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper);
final File some_extension_dir = temporaryFolder.newFolder();
final File a_jar = new File(some_extension_dir, "a.jar");
final File b_jar = new File(some_extension_dir, "b.jar");
final File c_jar = new File(some_extension_dir, "c.jar");
a_jar.createNewFile();
b_jar.createNewFile();
c_jar.createNewFile();
final URLClassLoader loader = extnLoader.getClassLoaderForExtension(some_extension_dir, false);
createNewJar(a_jar, jarFileContents);
createNewJar(b_jar, jarFileContents);
createNewJar(c_jar, jarFileContents);
final StandardURLClassLoader loader = extnLoader.getClassLoaderForExtension(some_extension_dir, false);
final URL[] expectedURLs = new URL[]{a_jar.toURI().toURL(), b_jar.toURI().toURL(), c_jar.toURI().toURL()};
final URL[] actualURLs = loader.getURLs();
Arrays.sort(actualURLs, Comparator.comparing(URL::getPath));
@ -109,7 +124,7 @@ public class ExtensionsLoaderTest
@Test
public void testGetLoadedModules()
{
final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig());
final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper);
Collection<DruidModule> modules = extnLoader.getModules();
HashSet<DruidModule> moduleSet = new HashSet<>(modules);
@ -134,7 +149,7 @@ public class ExtensionsLoaderTest
{
return tmpDir.getAbsolutePath();
}
});
}, objectMapper);
Assert.assertArrayEquals(
"Non-exist root extensionsDir should return an empty array of File",
new File[]{},
@ -155,7 +170,7 @@ public class ExtensionsLoaderTest
return extensionsDir.getAbsolutePath();
}
};
final ExtensionsLoader extnLoader = new ExtensionsLoader(config);
final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper);
extnLoader.getExtensionFilesToLoad();
}
@ -172,7 +187,7 @@ public class ExtensionsLoaderTest
}
};
final ExtensionsLoader extnLoader = new ExtensionsLoader(config);
final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper);
Assert.assertArrayEquals(
"Empty root extensionsDir should return an empty array of File",
new File[]{},
@ -196,7 +211,7 @@ public class ExtensionsLoaderTest
return extensionsDir.getAbsolutePath();
}
};
final ExtensionsLoader extnLoader = new ExtensionsLoader(config);
final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper);
final File mysql_metadata_storage = new File(extensionsDir, "mysql-metadata-storage");
mysql_metadata_storage.mkdir();
@ -231,7 +246,7 @@ public class ExtensionsLoaderTest
return extensionsDir.getAbsolutePath();
}
};
final ExtensionsLoader extnLoader = new ExtensionsLoader(config);
final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper);
final File mysql_metadata_storage = new File(extensionsDir, "mysql-metadata-storage");
final File random_extension = new File(extensionsDir, "random-extensions");
@ -267,7 +282,7 @@ public class ExtensionsLoaderTest
};
final File random_extension = new File(extensionsDir, "random-extensions");
random_extension.mkdir();
final ExtensionsLoader extnLoader = new ExtensionsLoader(config);
final ExtensionsLoader extnLoader = new ExtensionsLoader(config, objectMapper);
extnLoader.getExtensionFilesToLoad();
}
@ -320,14 +335,139 @@ public class ExtensionsLoaderTest
final File jar1 = new File(extension1, "jar1.jar");
final File jar2 = new File(extension2, "jar2.jar");
Assert.assertTrue(jar1.createNewFile());
Assert.assertTrue(jar2.createNewFile());
createNewJar(jar1, jarFileContents);
createNewJar(jar2, jarFileContents);
final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig());
final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper);
final ClassLoader classLoader1 = extnLoader.getClassLoaderForExtension(extension1, false);
final ClassLoader classLoader2 = extnLoader.getClassLoaderForExtension(extension2, false);
Assert.assertArrayEquals(new URL[]{jar1.toURI().toURL()}, ((URLClassLoader) classLoader1).getURLs());
Assert.assertArrayEquals(new URL[]{jar2.toURI().toURL()}, ((URLClassLoader) classLoader2).getURLs());
Assert.assertArrayEquals(new URL[]{jar1.toURI().toURL()}, ((StandardURLClassLoader) classLoader1).getURLs());
Assert.assertArrayEquals(new URL[]{jar2.toURI().toURL()}, ((StandardURLClassLoader) classLoader2).getURLs());
}
@Test
public void testGetClassLoaderForExtension_withMissingDependency() throws IOException
{
final ExtensionsLoader extnLoader = new ExtensionsLoader(new ExtensionsConfig(), objectMapper);
final String druidExtensionDependency = "other-druid-extension";
final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of(druidExtensionDependency));
final File extensionDir = temporaryFolder.newFolder();
final File extensionJar = new File(extensionDir, "a.jar");
createNewJar(extensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies)));
RE exception = Assert.assertThrows(RE.class, () -> {
extnLoader.getClassLoaderForExtension(extensionDir, false);
});
Assert.assertEquals(
StringUtils.format("Extension [%s] depends on [%s] which is not a valid extension or not loaded.", extensionDir.getName(), druidExtensionDependency),
exception.getMessage()
);
}
@Test
public void testGetClassLoaderForExtension_dependencyLoaded() throws IOException
{
ExtensionsConfig extensionsConfig = new TestExtensionsConfig(temporaryFolder.getRoot().getPath());
final ExtensionsLoader extnLoader = new ExtensionsLoader(extensionsConfig, objectMapper);
final File extensionDir = temporaryFolder.newFolder();
final File extensionJar = new File(extensionDir, "a.jar");
createNewJar(extensionJar, jarFileContents);
final File dependentExtensionDir = temporaryFolder.newFolder();
final File dependentExtensionJar = new File(dependentExtensionDir, "a.jar");
final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of(extensionDir.getName()));
createNewJar(dependentExtensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies)));
StandardURLClassLoader classLoader = extnLoader.getClassLoaderForExtension(extensionDir, false);
StandardURLClassLoader dependendentClassLoader = extnLoader.getClassLoaderForExtension(dependentExtensionDir, false);
Assert.assertTrue(dependendentClassLoader.getExtensionDependencyClassLoaders().contains(classLoader));
Assert.assertEquals(0, classLoader.getExtensionDependencyClassLoaders().size());
}
@Test
public void testGetClassLoaderForExtension_circularDependency() throws IOException
{
ExtensionsConfig extensionsConfig = new TestExtensionsConfig(temporaryFolder.getRoot().getPath());
final ExtensionsLoader extnLoader = new ExtensionsLoader(extensionsConfig, objectMapper);
final File extensionDir = temporaryFolder.newFolder();
final File dependentExtensionDir = temporaryFolder.newFolder();
final File extensionJar = new File(extensionDir, "a.jar");
final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of(dependentExtensionDir.getName()));
createNewJar(extensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies)));
final File dependentExtensionJar = new File(dependentExtensionDir, "a.jar");
final DruidExtensionDependencies druidExtensionDependenciesCircular = new DruidExtensionDependencies(ImmutableList.of(extensionDir.getName()));
createNewJar(dependentExtensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependenciesCircular)));
RE exception = Assert.assertThrows(RE.class, () -> {
extnLoader.getClassLoaderForExtension(extensionDir, false);
});
Assert.assertTrue(exception.getMessage().contains("has a circular druid extension dependency."));
}
@Test
public void testGetClassLoaderForExtension_multipleDruidJars() throws IOException
{
ExtensionsConfig extensionsConfig = new TestExtensionsConfig(temporaryFolder.getRoot().getPath());
final ExtensionsLoader extnLoader = new ExtensionsLoader(extensionsConfig, objectMapper);
final File extensionDir = temporaryFolder.newFolder();
final File extensionJar = new File(extensionDir, "a.jar");
final DruidExtensionDependencies druidExtensionDependencies = new DruidExtensionDependencies(ImmutableList.of());
createNewJar(extensionJar, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies)));
final File extensionJar2 = new File(extensionDir, "b.jar");
createNewJar(extensionJar2, ImmutableMap.of("druid-extension-dependencies.json", objectMapper.writeValueAsBytes(druidExtensionDependencies)));
RE exception = Assert.assertThrows(RE.class, () -> {
extnLoader.getClassLoaderForExtension(extensionDir, false);
});
Assert.assertTrue(
exception.getMessage().contains("Each jar should be in a separate extension directory.")
);
}
private void createNewJar(File jarFileLocation, Map<String, byte[]> jarFileContents) throws IOException
{
Assert.assertTrue(jarFileLocation.createNewFile());
FileOutputStream fos = new FileOutputStream(jarFileLocation.getPath());
JarOutputStream jarOut = new JarOutputStream(fos);
for (Map.Entry<String, byte[]> fileNameToContents : jarFileContents.entrySet()) {
JarEntry entry = new JarEntry(fileNameToContents.getKey());
jarOut.putNextEntry(entry);
jarOut.write(fileNameToContents.getValue());
jarOut.closeEntry();
}
jarOut.close();
fos.close();
}
private static class TestExtensionsConfig extends ExtensionsConfig
{
final String directory;
public TestExtensionsConfig(String directory)
{
this.directory = directory;
}
@Override
public String getDirectory()
{
return directory;
}
}
}