mirror of https://github.com/apache/openjpa.git
extracting directory enhancement logic in a runnable for junit5 module to be able to reuse it easily
This commit is contained in:
parent
4c9ac41fd1
commit
beb125500f
|
@ -0,0 +1,360 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.apache.openjpa.junit5.internal;
|
||||
|
||||
import org.apache.openjpa.conf.OpenJPAConfigurationImpl;
|
||||
import org.apache.openjpa.enhance.AsmAdaptor;
|
||||
import org.apache.openjpa.enhance.PCEnhancer;
|
||||
import org.apache.openjpa.enhance.PersistenceCapable;
|
||||
import org.apache.openjpa.lib.log.LogFactory;
|
||||
import org.apache.openjpa.lib.log.LogFactoryImpl;
|
||||
import org.apache.openjpa.lib.log.SLF4JLogFactory;
|
||||
import org.apache.openjpa.meta.MetaDataRepository;
|
||||
import org.apache.openjpa.persistence.PersistenceMetaDataFactory;
|
||||
import org.apache.xbean.asm7.AnnotationVisitor;
|
||||
import org.apache.xbean.asm7.ClassReader;
|
||||
import org.apache.xbean.asm7.Type;
|
||||
import org.apache.xbean.asm7.shade.commons.EmptyVisitor;
|
||||
import org.apache.xbean.finder.ClassLoaders;
|
||||
import serp.bytecode.BCClass;
|
||||
import serp.bytecode.Project;
|
||||
|
||||
import javax.persistence.Embeddable;
|
||||
import javax.persistence.Entity;
|
||||
import javax.persistence.MappedSuperclass;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URL;
|
||||
import java.nio.file.FileVisitResult;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.SimpleFileVisitor;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.nio.file.attribute.BasicFileAttributes;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.apache.xbean.asm7.ClassReader.SKIP_CODE;
|
||||
import static org.apache.xbean.asm7.ClassReader.SKIP_DEBUG;
|
||||
import static org.apache.xbean.asm7.ClassReader.SKIP_FRAMES;
|
||||
|
||||
public class OpenJPADirectoriesEnhancer implements Runnable {
|
||||
private static final Logger LOGGER = Logger.getLogger(OpenJPADirectoriesEnhancer.class.getName());
|
||||
public static final StackTraceElement[] NO_STACK_TRACE = new StackTraceElement[0];
|
||||
|
||||
private final boolean auto;
|
||||
private final String[] entities;
|
||||
private final Class<?> logFactory;
|
||||
|
||||
public OpenJPADirectoriesEnhancer(final boolean auto, final String[] entities, final Class<?> logFactory) {
|
||||
this.auto = auto;
|
||||
this.entities = entities;
|
||||
this.logFactory = logFactory;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
|
||||
final OpenJpaClassLoader enhancementClassLoader = new OpenJpaClassLoader(
|
||||
classLoader, createLogFactory(classLoader));
|
||||
final Thread thread = Thread.currentThread();
|
||||
thread.setContextClassLoader(enhancementClassLoader);
|
||||
try {
|
||||
if (auto) {
|
||||
try {
|
||||
ClassLoaders.findUrls(enhancementClassLoader.getParent()).stream()
|
||||
.map(org.apache.xbean.finder.util.Files::toFile)
|
||||
.filter(File::isDirectory)
|
||||
.map(File::toPath)
|
||||
.forEach(dir -> {
|
||||
LOGGER.fine(() -> "Enhancing folder '" + dir + "'");
|
||||
try {
|
||||
enhanceDirectory(enhancementClassLoader, dir);
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
});
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
} else {
|
||||
Stream.of(entities).forEach(e -> {
|
||||
try {
|
||||
enhancementClassLoader.loadClass(e);
|
||||
} catch (final ClassNotFoundException e1) {
|
||||
throw new IllegalArgumentException(e1);
|
||||
}
|
||||
});
|
||||
}
|
||||
} finally {
|
||||
thread.setContextClassLoader(enhancementClassLoader.getParent());
|
||||
}
|
||||
}
|
||||
|
||||
private LogFactory createLogFactory(final ClassLoader classLoader) {
|
||||
try {
|
||||
if (logFactory == null || logFactory == LogFactory.class) {
|
||||
try {
|
||||
return new SLF4JLogFactory();
|
||||
} catch (final Error | Exception e) {
|
||||
return new LogFactoryImpl();
|
||||
}
|
||||
}
|
||||
return logFactory.asSubclass(LogFactory.class).getConstructor().newInstance();
|
||||
} catch (final RuntimeException e) {
|
||||
throw e;
|
||||
} catch (final Exception e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void enhanceDirectory(final OpenJpaClassLoader enhancementClassLoader, final Path dir) throws IOException {
|
||||
Files.walkFileTree(dir, new SimpleFileVisitor<Path>() {
|
||||
@Override
|
||||
public FileVisitResult visitFile(final Path file, final BasicFileAttributes attrs) throws IOException {
|
||||
if (file.getFileName().toString().endsWith(".class")) {
|
||||
final String relativeName = dir.relativize(file).toString();
|
||||
try {
|
||||
enhancementClassLoader.handleEnhancement(
|
||||
relativeName.substring(0, relativeName.length() - ".class".length()));
|
||||
} catch (final ClassNotFoundException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
}
|
||||
return super.visitFile(file, attrs);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private static abstract class BaseClassLoader extends ClassLoader {
|
||||
private BaseClassLoader(final ClassLoader parent) {
|
||||
super(parent);
|
||||
}
|
||||
|
||||
protected abstract Class<?> doLoadClass(String name, boolean resolve) throws ClassNotFoundException;
|
||||
|
||||
@Override
|
||||
protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
if (name != null && !name.startsWith("java") && !name.startsWith("sun") && !name.startsWith("jdk")) {
|
||||
return doLoadClass(name, resolve);
|
||||
}
|
||||
return defaultLoadClass(name, resolve);
|
||||
}
|
||||
|
||||
protected Class<?> defaultLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
return super.loadClass(name, resolve);
|
||||
}
|
||||
|
||||
protected byte[] loadBytes(final String name) {
|
||||
final URL url = findUrl(name);
|
||||
if (url == null || "jar".equals(url.getProtocol()) /*assume done in build*/) {
|
||||
return null;
|
||||
}
|
||||
byte[] buffer = new byte[4096];
|
||||
final ByteArrayOutputStream inMem = new ByteArrayOutputStream(buffer.length);
|
||||
try (final InputStream is = url.openStream()) {
|
||||
int read;
|
||||
while ((read = is.read(buffer)) >= 0) {
|
||||
if (read > 0) {
|
||||
inMem.write(buffer, 0, read);
|
||||
}
|
||||
}
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
return inMem.toByteArray();
|
||||
}
|
||||
|
||||
protected URL findUrl(final String name) {
|
||||
return getResource(name.replace('.', '/') + ".class");
|
||||
}
|
||||
}
|
||||
|
||||
private static class OpenJpaClassLoader extends BaseClassLoader {
|
||||
private static final String PERSITENCE_CAPABLE = Type.getDescriptor(PersistenceCapable.class);
|
||||
private static final String ENTITY = Type.getDescriptor(Entity.class);
|
||||
private static final String EMBEDDABLE = Type.getDescriptor(Embeddable.class);
|
||||
private static final String MAPPED_SUPERCLASS = Type.getDescriptor(MappedSuperclass.class);
|
||||
|
||||
private final MetaDataRepository repos;
|
||||
private final ClassLoader tmpLoader;
|
||||
private final Collection<String> alreadyEnhanced = new ArrayList<>();
|
||||
|
||||
private OpenJpaClassLoader(final ClassLoader parent, final LogFactory logFactory) {
|
||||
super(parent);
|
||||
|
||||
final OpenJPAConfigurationImpl conf = new OpenJPAConfigurationImpl();
|
||||
conf.setLogFactory(logFactory);
|
||||
|
||||
tmpLoader = new CompanionLoader(parent);
|
||||
repos = new MetaDataRepository();
|
||||
repos.setConfiguration(conf);
|
||||
repos.setMetaDataFactory(new PersistenceMetaDataFactory());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected synchronized Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
final Class<?> clazz = findLoadedClass(name);
|
||||
if (clazz != null) {
|
||||
if (resolve) {
|
||||
resolveClass(clazz);
|
||||
}
|
||||
return clazz;
|
||||
}
|
||||
handleEnhancement(name);
|
||||
return defaultLoadClass(name, resolve);
|
||||
}
|
||||
|
||||
private void handleEnhancement(final String name) throws ClassNotFoundException {
|
||||
final byte[] enhanced = ensureEnhancedIfNeeded(name);
|
||||
if (enhanced != null && alreadyEnhanced.add(name)) {
|
||||
// we could do that but test classes will be loaded with parent loader
|
||||
// so just rewrite the class on the fly assuming it was not yet read
|
||||
try {
|
||||
Files.write(findTarget(name), enhanced, StandardOpenOption.TRUNCATE_EXISTING);
|
||||
LOGGER.info(() -> "Enhanced '" + name + "'");
|
||||
} catch (final IOException e) {
|
||||
throw new ClassNotFoundException(e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Path findTarget(final String name) {
|
||||
final URL url = findUrl(name);
|
||||
if (!"file".equals(url.getProtocol())) {
|
||||
throw new IllegalStateException("Only file urls are supported today: " + url);
|
||||
}
|
||||
return Paths.get(url.getPath());
|
||||
}
|
||||
|
||||
private byte[] enhance(final byte[] classBytes) {
|
||||
final Thread thread = Thread.currentThread();
|
||||
final ClassLoader old = thread.getContextClassLoader();
|
||||
thread.setContextClassLoader(tmpLoader);
|
||||
try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
|
||||
final PCEnhancer enhancer = new PCEnhancer(
|
||||
repos.getConfiguration(),
|
||||
new Project().loadClass(stream, tmpLoader),
|
||||
repos, tmpLoader);
|
||||
if (enhancer.run() == PCEnhancer.ENHANCE_NONE) {
|
||||
return null;
|
||||
}
|
||||
final BCClass pcb = enhancer.getPCBytecode();
|
||||
return AsmAdaptor.toByteArray(pcb, pcb.toByteArray());
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
} finally {
|
||||
thread.setContextClassLoader(old);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isJpaButNotEnhanced(final byte[] classBytes) {
|
||||
try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
|
||||
final ClassReader reader = new ClassReader(stream);
|
||||
reader.accept(new EmptyVisitor() {
|
||||
@Override
|
||||
public void visit(final int version, final int access, final String name,
|
||||
final String signature, final String superName, final String[] interfaces) {
|
||||
if (interfaces != null && asList(interfaces).contains(PERSITENCE_CAPABLE)) {
|
||||
throw new AlreadyEnhanced(); // exit
|
||||
}
|
||||
super.visit(version, access, name, signature, superName, interfaces);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(final String descriptor, final boolean visible) {
|
||||
if (ENTITY.equals(descriptor) ||
|
||||
EMBEDDABLE.equals(descriptor) ||
|
||||
MAPPED_SUPERCLASS.equals(descriptor)) {
|
||||
throw new MissingEnhancement(); // we already went into visit() so we miss the enhancement
|
||||
}
|
||||
return new EmptyVisitor().visitAnnotation(descriptor, visible);
|
||||
}
|
||||
}, SKIP_DEBUG + SKIP_CODE + SKIP_FRAMES);
|
||||
return false;
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
} catch (final AlreadyEnhanced alreadyEnhanced) {
|
||||
return false;
|
||||
} catch (final MissingEnhancement alreadyEnhanced) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] ensureEnhancedIfNeeded(final String name) {
|
||||
final byte[] classBytes = loadBytes(name);
|
||||
if (classBytes == null) {
|
||||
return null;
|
||||
}
|
||||
if (isJpaButNotEnhanced(classBytes)) {
|
||||
final byte[] enhanced = enhance(classBytes);
|
||||
if (enhanced != null) {
|
||||
return enhanced;
|
||||
}
|
||||
LOGGER.info("'" + name + "' already enhanced");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static class CompanionLoader extends BaseClassLoader {
|
||||
private CompanionLoader(final ClassLoader parent) {
|
||||
super(parent);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
final Class<?> clazz = findLoadedClass(name);
|
||||
if (clazz != null) {
|
||||
if (resolve) {
|
||||
resolveClass(clazz);
|
||||
}
|
||||
return clazz;
|
||||
}
|
||||
final byte[] content = loadBytes(name);
|
||||
if (content != null) {
|
||||
final Class<?> value = super.defineClass(name, content, 0, content.length);
|
||||
if (resolve) {
|
||||
resolveClass(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
return defaultLoadClass(name, resolve);
|
||||
}
|
||||
}
|
||||
|
||||
private static class MissingEnhancement extends RuntimeException {
|
||||
private MissingEnhancement() {
|
||||
setStackTrace(NO_STACK_TRACE);
|
||||
}
|
||||
}
|
||||
|
||||
private static class AlreadyEnhanced extends RuntimeException {
|
||||
private AlreadyEnhanced() {
|
||||
setStackTrace(NO_STACK_TRACE);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -18,333 +18,19 @@
|
|||
*/
|
||||
package org.apache.openjpa.junit5.internal;
|
||||
|
||||
import org.apache.openjpa.conf.OpenJPAConfigurationImpl;
|
||||
import org.apache.openjpa.enhance.AsmAdaptor;
|
||||
import org.apache.openjpa.enhance.PCEnhancer;
|
||||
import org.apache.openjpa.enhance.PersistenceCapable;
|
||||
import org.apache.openjpa.junit5.OpenJPASupport;
|
||||
import org.apache.openjpa.lib.log.LogFactory;
|
||||
import org.apache.openjpa.lib.log.LogFactoryImpl;
|
||||
import org.apache.openjpa.lib.log.SLF4JLogFactory;
|
||||
import org.apache.openjpa.meta.MetaDataRepository;
|
||||
import org.apache.openjpa.persistence.PersistenceMetaDataFactory;
|
||||
import org.apache.xbean.asm7.AnnotationVisitor;
|
||||
import org.apache.xbean.asm7.ClassReader;
|
||||
import org.apache.xbean.asm7.Type;
|
||||
import org.apache.xbean.asm7.shade.commons.EmptyVisitor;
|
||||
import org.apache.xbean.finder.ClassLoaders;
|
||||
import org.junit.jupiter.api.extension.BeforeAllCallback;
|
||||
import org.junit.jupiter.api.extension.ExtensionContext;
|
||||
import org.junit.platform.commons.util.AnnotationUtils;
|
||||
import serp.bytecode.BCClass;
|
||||
import serp.bytecode.Project;
|
||||
|
||||
import javax.persistence.Embeddable;
|
||||
import javax.persistence.Entity;
|
||||
import javax.persistence.MappedSuperclass;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.net.URL;
|
||||
import java.nio.file.FileVisitResult;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.SimpleFileVisitor;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.nio.file.attribute.BasicFileAttributes;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.apache.xbean.asm7.ClassReader.SKIP_CODE;
|
||||
import static org.apache.xbean.asm7.ClassReader.SKIP_DEBUG;
|
||||
import static org.apache.xbean.asm7.ClassReader.SKIP_FRAMES;
|
||||
|
||||
public class OpenJPAExtension implements BeforeAllCallback {
|
||||
private static final Logger LOGGER = Logger.getLogger(OpenJPAExtension.class.getName());
|
||||
|
||||
@Override
|
||||
public void beforeAll(final ExtensionContext context) {
|
||||
AnnotationUtils.findAnnotation(context.getElement(), OpenJPASupport.class).ifPresent(s -> {
|
||||
final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
|
||||
final OpenJpaClassLoader enhancementClassLoader = new OpenJpaClassLoader(
|
||||
classLoader, createLogFactory(classLoader, s.logFactory()));
|
||||
final Thread thread = Thread.currentThread();
|
||||
thread.setContextClassLoader(enhancementClassLoader);
|
||||
try {
|
||||
if (s.auto()) {
|
||||
try {
|
||||
ClassLoaders.findUrls(enhancementClassLoader.getParent()).stream()
|
||||
.map(org.apache.xbean.finder.util.Files::toFile)
|
||||
.filter(File::isDirectory)
|
||||
.map(File::toPath)
|
||||
.forEach(dir -> {
|
||||
LOGGER.fine(() -> "Enhancing folder '" + dir + "'");
|
||||
try {
|
||||
enhanceDirectory(enhancementClassLoader, dir);
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
});
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
} else {
|
||||
Stream.of(s.entities()).forEach(e -> {
|
||||
try {
|
||||
enhancementClassLoader.loadClass(e);
|
||||
} catch (final ClassNotFoundException e1) {
|
||||
throw new IllegalArgumentException(e1);
|
||||
}
|
||||
});
|
||||
}
|
||||
} finally {
|
||||
thread.setContextClassLoader(enhancementClassLoader.getParent());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private LogFactory createLogFactory(final ClassLoader classLoader, final Class<?> logFactory) {
|
||||
try {
|
||||
if (logFactory == LogFactory.class) {
|
||||
try {
|
||||
return new SLF4JLogFactory();
|
||||
} catch (final Error | Exception e) {
|
||||
return new LogFactoryImpl();
|
||||
}
|
||||
}
|
||||
return logFactory.asSubclass(LogFactory.class).getConstructor().newInstance();
|
||||
} catch (final RuntimeException e) {
|
||||
throw e;
|
||||
} catch (final Exception e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void enhanceDirectory(final OpenJpaClassLoader enhancementClassLoader, final Path dir) throws IOException {
|
||||
Files.walkFileTree(dir, new SimpleFileVisitor<Path>() {
|
||||
@Override
|
||||
public FileVisitResult visitFile(final Path file, final BasicFileAttributes attrs) throws IOException {
|
||||
if (file.getFileName().toString().endsWith(".class")) {
|
||||
final String relativeName = dir.relativize(file).toString();
|
||||
try {
|
||||
enhancementClassLoader.handleEnhancement(
|
||||
relativeName.substring(0, relativeName.length() - ".class".length()));
|
||||
} catch (final ClassNotFoundException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
}
|
||||
return super.visitFile(file, attrs);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private static abstract class BaseClassLoader extends ClassLoader {
|
||||
private BaseClassLoader(final ClassLoader parent) {
|
||||
super(parent);
|
||||
}
|
||||
|
||||
protected abstract Class<?> doLoadClass(String name, boolean resolve) throws ClassNotFoundException;
|
||||
|
||||
@Override
|
||||
protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
if (name != null && !name.startsWith("java") && !name.startsWith("sun") && !name.startsWith("jdk")) {
|
||||
return doLoadClass(name, resolve);
|
||||
}
|
||||
return defaultLoadClass(name, resolve);
|
||||
}
|
||||
|
||||
protected Class<?> defaultLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
return super.loadClass(name, resolve);
|
||||
}
|
||||
|
||||
protected byte[] loadBytes(final String name) {
|
||||
final URL url = findUrl(name);
|
||||
if (url == null || "jar".equals(url.getProtocol()) /*assume done in build*/) {
|
||||
return null;
|
||||
}
|
||||
byte[] buffer = new byte[4096];
|
||||
final ByteArrayOutputStream inMem = new ByteArrayOutputStream(buffer.length);
|
||||
try (final InputStream is = url.openStream()) {
|
||||
int read;
|
||||
while ((read = is.read(buffer)) >= 0) {
|
||||
if (read > 0) {
|
||||
inMem.write(buffer, 0, read);
|
||||
}
|
||||
}
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
return inMem.toByteArray();
|
||||
}
|
||||
|
||||
protected URL findUrl(final String name) {
|
||||
return getResource(name.replace('.', '/') + ".class");
|
||||
}
|
||||
}
|
||||
|
||||
private static class OpenJpaClassLoader extends BaseClassLoader {
|
||||
private static final String PERSITENCE_CAPABLE = Type.getDescriptor(PersistenceCapable.class);
|
||||
private static final String ENTITY = Type.getDescriptor(Entity.class);
|
||||
private static final String EMBEDDABLE = Type.getDescriptor(Embeddable.class);
|
||||
private static final String MAPPED_SUPERCLASS = Type.getDescriptor(MappedSuperclass.class);
|
||||
|
||||
private final MetaDataRepository repos;
|
||||
private final ClassLoader tmpLoader;
|
||||
private final Collection<String> alreadyEnhanced = new ArrayList<>();
|
||||
|
||||
private OpenJpaClassLoader(final ClassLoader parent, final LogFactory logFactory) {
|
||||
super(parent);
|
||||
|
||||
final OpenJPAConfigurationImpl conf = new OpenJPAConfigurationImpl();
|
||||
conf.setLogFactory(logFactory);
|
||||
|
||||
tmpLoader = new CompanionLoader(parent);
|
||||
repos = new MetaDataRepository();
|
||||
repos.setConfiguration(conf);
|
||||
repos.setMetaDataFactory(new PersistenceMetaDataFactory());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected synchronized Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
final Class<?> clazz = findLoadedClass(name);
|
||||
if (clazz != null) {
|
||||
if (resolve) {
|
||||
resolveClass(clazz);
|
||||
}
|
||||
return clazz;
|
||||
}
|
||||
handleEnhancement(name);
|
||||
return defaultLoadClass(name, resolve);
|
||||
}
|
||||
|
||||
private void handleEnhancement(final String name) throws ClassNotFoundException {
|
||||
final byte[] enhanced = ensureEnhancedIfNeeded(name);
|
||||
if (enhanced != null && alreadyEnhanced.add(name)) {
|
||||
// we could do that but test classes will be loaded with parent loader
|
||||
// so just rewrite the class on the fly assuming it was not yet read
|
||||
try {
|
||||
Files.write(findTarget(name), enhanced, StandardOpenOption.TRUNCATE_EXISTING);
|
||||
LOGGER.info(() -> "Enhanced '" + name + "'");
|
||||
} catch (final IOException e) {
|
||||
throw new ClassNotFoundException(e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Path findTarget(final String name) {
|
||||
final URL url = findUrl(name);
|
||||
if (!"file".equals(url.getProtocol())) {
|
||||
throw new IllegalStateException("Only file urls are supported today: " + url);
|
||||
}
|
||||
return Paths.get(url.getPath());
|
||||
}
|
||||
|
||||
private byte[] enhance(final byte[] classBytes) {
|
||||
final Thread thread = Thread.currentThread();
|
||||
final ClassLoader old = thread.getContextClassLoader();
|
||||
thread.setContextClassLoader(tmpLoader);
|
||||
try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
|
||||
final PCEnhancer enhancer = new PCEnhancer(
|
||||
repos.getConfiguration(),
|
||||
new Project().loadClass(stream, tmpLoader),
|
||||
repos, tmpLoader);
|
||||
if (enhancer.run() == PCEnhancer.ENHANCE_NONE) {
|
||||
return null;
|
||||
}
|
||||
final BCClass pcb = enhancer.getPCBytecode();
|
||||
return AsmAdaptor.toByteArray(pcb, pcb.toByteArray());
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
} finally {
|
||||
thread.setContextClassLoader(old);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isJpaButNotEnhanced(final byte[] classBytes) {
|
||||
try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
|
||||
final ClassReader reader = new ClassReader(stream);
|
||||
reader.accept(new EmptyVisitor() {
|
||||
@Override
|
||||
public void visit(final int version, final int access, final String name,
|
||||
final String signature, final String superName, final String[] interfaces) {
|
||||
if (interfaces != null && asList(interfaces).contains(PERSITENCE_CAPABLE)) {
|
||||
throw new AlreadyEnhanced(); // exit
|
||||
}
|
||||
super.visit(version, access, name, signature, superName, interfaces);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(final String descriptor, final boolean visible) {
|
||||
if (ENTITY.equals(descriptor) ||
|
||||
EMBEDDABLE.equals(descriptor) ||
|
||||
MAPPED_SUPERCLASS.equals(descriptor)) {
|
||||
throw new MissingEnhancement(); // we already went into visit() so we miss the enhancement
|
||||
}
|
||||
return new EmptyVisitor().visitAnnotation(descriptor, visible);
|
||||
}
|
||||
}, SKIP_DEBUG + SKIP_CODE + SKIP_FRAMES);
|
||||
return false;
|
||||
} catch (final IOException e) {
|
||||
throw new IllegalStateException(e);
|
||||
} catch (final AlreadyEnhanced alreadyEnhanced) {
|
||||
return false;
|
||||
} catch (final MissingEnhancement alreadyEnhanced) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] ensureEnhancedIfNeeded(final String name) {
|
||||
final byte[] classBytes = loadBytes(name);
|
||||
if (classBytes == null) {
|
||||
return null;
|
||||
}
|
||||
if (isJpaButNotEnhanced(classBytes)) {
|
||||
final byte[] enhanced = enhance(classBytes);
|
||||
if (enhanced != null) {
|
||||
return enhanced;
|
||||
}
|
||||
LOGGER.info("'" + name + "' already enhanced");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static class CompanionLoader extends BaseClassLoader {
|
||||
private CompanionLoader(final ClassLoader parent) {
|
||||
super(parent);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
|
||||
final Class<?> clazz = findLoadedClass(name);
|
||||
if (clazz != null) {
|
||||
if (resolve) {
|
||||
resolveClass(clazz);
|
||||
}
|
||||
return clazz;
|
||||
}
|
||||
final byte[] content = loadBytes(name);
|
||||
if (content != null) {
|
||||
final Class<?> value = super.defineClass(name, content, 0, content.length);
|
||||
if (resolve) {
|
||||
resolveClass(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
return defaultLoadClass(name, resolve);
|
||||
}
|
||||
}
|
||||
|
||||
private static class MissingEnhancement extends RuntimeException {
|
||||
}
|
||||
|
||||
private static class AlreadyEnhanced extends RuntimeException {
|
||||
AnnotationUtils.findAnnotation(context.getElement(), OpenJPASupport.class)
|
||||
.ifPresent(s -> new OpenJPADirectoriesEnhancer(s.auto(), s.entities(), s.logFactory()).run());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue