HHH-13277 Simplify HibernateMethodLookupDispatcher

This commit is contained in:
Yoann Rodière 2019-02-22 12:48:16 +01:00 committed by Guillaume Smet
parent 0b3babe4fb
commit 38a0cd2690
1 changed files with 101 additions and 85 deletions

View File

@ -19,15 +19,17 @@ import org.hibernate.HibernateException;
public class HibernateMethodLookupDispatcher { public class HibernateMethodLookupDispatcher {
private static final SecurityActions SECURITY_ACTIONS = new SecurityActions(); /**
* The minimum number of stack frames to drop before we can hope to find the caller frame.
private static final Function<Object, Class<?>> STACK_FRAME_GET_DECLARING_CLASS_FUNCTION; */
@SuppressWarnings("rawtypes") private static final int MIN_STACK_FRAMES = 3;
private static final Function<Stream, Object> STACK_FRAME_EXTRACT_FUNCTION; /**
private static Object stackWalker; * The maximum number of stack frames to explore to find the caller frame.
private static Method stackWalkerWalkMethod; * <p>
private static Method stackFrameGetDeclaringClass; * Beyond that, we give up and throw an exception.
private static final PrivilegedAction<Class<?>> GET_CALLER_CLASS_ACTION; */
private static final int MAX_STACK_FRAMES = 16;
private static final PrivilegedAction<Class<?>[]> GET_CALLER_STACK_ACTION;
// Currently, the bytecode provider is created statically and shared between all the session factories. Thus we // Currently, the bytecode provider is created statically and shared between all the session factories. Thus we
// can't clear this set when we close a session factory as we might remove elements coming from another one. // can't clear this set when we close a session factory as we might remove elements coming from another one.
@ -85,10 +87,10 @@ public class HibernateMethodLookupDispatcher {
} }
static { static {
PrivilegedAction<Void> initializeGetCallerClassRequirementsAction = new PrivilegedAction<Void>() { // The action below will return the action used at runtime to retrieve the caller stack
PrivilegedAction<PrivilegedAction<Class<?>[]>> initializeGetCallerStackAction = new PrivilegedAction<PrivilegedAction<Class<?>[]>>() {
@Override @Override
public Void run() { public PrivilegedAction<Class<?>[]> run() {
Class<?> stackWalkerClass = null; Class<?> stackWalkerClass = null;
try { try {
stackWalkerClass = Class.forName( "java.lang.StackWalker" ); stackWalkerClass = Class.forName( "java.lang.StackWalker" );
@ -98,77 +100,50 @@ public class HibernateMethodLookupDispatcher {
} }
if ( stackWalkerClass != null ) { if ( stackWalkerClass != null ) {
// We can use a stack walker
try { try {
Class<?> optionClass = Class.forName( "java.lang.StackWalker$Option" ); Class<?> optionClass = Class.forName( "java.lang.StackWalker$Option" );
stackWalker = stackWalkerClass.getMethod( "getInstance", optionClass ) Object stackWalker = stackWalkerClass.getMethod( "getInstance", optionClass )
// The first one is RETAIN_CLASS_REFERENCE // The first one is RETAIN_CLASS_REFERENCE
.invoke( null, optionClass.getEnumConstants()[0] ); .invoke( null, optionClass.getEnumConstants()[0] );
stackWalkerWalkMethod = stackWalkerClass.getMethod( "walk", Function.class ); Method stackWalkerWalkMethod = stackWalkerClass.getMethod( "walk", Function.class );
stackFrameGetDeclaringClass = Class.forName( "java.lang.StackWalker$StackFrame" ) Method stackFrameGetDeclaringClass = Class.forName( "java.lang.StackWalker$StackFrame" )
.getMethod( "getDeclaringClass" ); .getMethod( "getDeclaringClass" );
return new StackWalkerGetCallerStackAction(
stackWalker, stackWalkerWalkMethod,stackFrameGetDeclaringClass
);
} }
catch (Throwable e) { catch (Throwable e) {
throw new HibernateException( "Unable to initialize the stack walker", e ); throw new HibernateException( "Unable to initialize the stack walker", e );
} }
} }
return null;
}
};
if ( System.getSecurityManager() != null ) {
AccessController.doPrivileged( initializeGetCallerClassRequirementsAction );
}
else { else {
initializeGetCallerClassRequirementsAction.run(); // We cannot use a stack walker, default to fetching the security manager class context
} return new SecurityManagerClassContextGetCallerStackAction();
STACK_FRAME_GET_DECLARING_CLASS_FUNCTION = new Function<Object, Class<?>>() {
@Override
public Class<?> apply(Object t) {
try {
return (Class<?>) stackFrameGetDeclaringClass.invoke( t );
}
catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
throw new HibernateException( "Unable to get stack frame declaring class", e );
} }
} }
}; };
STACK_FRAME_EXTRACT_FUNCTION = new Function<Stream, Object>() { GET_CALLER_STACK_ACTION = System.getSecurityManager() != null
? AccessController.doPrivileged( initializeGetCallerStackAction )
@Override : initializeGetCallerStackAction.run();
@SuppressWarnings({ "unchecked", "rawtypes" })
public Object apply(Stream stream) {
return stream.map( STACK_FRAME_GET_DECLARING_CLASS_FUNCTION )
.limit( 16 )
.toArray( Class<?>[]::new );
} }
};
GET_CALLER_CLASS_ACTION = new PrivilegedAction<Class<?>>() { private static Class<?> getCallerClass() {
Class<?>[] stackTrace = System.getSecurityManager() != null
@Override ? AccessController.doPrivileged( GET_CALLER_STACK_ACTION )
public Class<?> run() { : GET_CALLER_STACK_ACTION.run();
try {
Class<?>[] stackTrace;
if ( stackWalker != null ) {
stackTrace = (Class<?>[]) stackWalkerWalkMethod.invoke( stackWalker, STACK_FRAME_EXTRACT_FUNCTION );
}
else {
stackTrace = SECURITY_ACTIONS.getCallerClass();
}
// this shouldn't happen but let's be safe // this shouldn't happen but let's be safe
if ( stackTrace.length < 4 ) { if ( stackTrace.length <= MIN_STACK_FRAMES ) {
throw new SecurityException( "Unable to determine the caller class" ); throw new SecurityException( "Unable to determine the caller class" );
} }
boolean hibernateMethodLookupDispatcherDetected = false; boolean hibernateMethodLookupDispatcherDetected = false;
// start at the 4th frame and limit that to the 16 first frames // start at the 4th frame and limit that to the MAX_STACK_FRAMES first frames
int maxFrames = Math.min( 16, stackTrace.length ); int maxFrames = Math.min( MAX_STACK_FRAMES, stackTrace.length );
for ( int i = 3; i < maxFrames; i++ ) { for ( int i = MIN_STACK_FRAMES; i < maxFrames; i++ ) {
if ( stackTrace[i].getName().equals( HibernateMethodLookupDispatcher.class.getName() ) ) { if ( stackTrace[i].getName().equals( HibernateMethodLookupDispatcher.class.getName() ) ) {
hibernateMethodLookupDispatcherDetected = true; hibernateMethodLookupDispatcherDetected = true;
continue; continue;
@ -180,22 +155,63 @@ public class HibernateMethodLookupDispatcher {
throw new SecurityException( "Unable to determine the caller class" ); throw new SecurityException( "Unable to determine the caller class" );
} }
/**
* A privileged action that retrieves the caller stack from the security manager class context.
*/
private static class SecurityManagerClassContextGetCallerStackAction extends SecurityManager
implements PrivilegedAction<Class<?>[]> {
@Override
public Class<?>[] run() {
return getClassContext();
}
}
/**
* A privileged action that retrieves the caller stack using a stack walker.
*/
private static class StackWalkerGetCallerStackAction implements PrivilegedAction<Class<?>[]> {
private final Object stackWalker;
private final Method stackWalkerWalkMethod;
private final Method stackFrameGetDeclaringClass;
StackWalkerGetCallerStackAction(Object stackWalker, Method stackWalkerWalkMethod,
Method stackFrameGetDeclaringClass) {
this.stackWalker = stackWalker;
this.stackWalkerWalkMethod = stackWalkerWalkMethod;
this.stackFrameGetDeclaringClass = stackFrameGetDeclaringClass;
}
@Override
public Class<?>[] run() {
try {
return (Class<?>[]) stackWalkerWalkMethod.invoke( stackWalker, stackFrameExtractFunction );
}
catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
throw new SecurityException( "Unable to determine the caller class", e ); throw new SecurityException( "Unable to determine the caller class", e );
} }
} }
private final Function<Stream, Object> stackFrameExtractFunction = new Function<Stream, Object>() {
@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public Object apply(Stream stream) {
return stream.map( stackFrameGetDeclaringClassFunction )
.limit( MAX_STACK_FRAMES )
.toArray( Class<?>[]::new );
}
};
private final Function<Object, Class<?>> stackFrameGetDeclaringClassFunction = new Function<Object, Class<?>>() {
@Override
public Class<?> apply(Object t) {
try {
return (Class<?>) stackFrameGetDeclaringClass.invoke( t );
}
catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
throw new HibernateException( "Unable to get stack frame declaring class", e );
}
}
}; };
} }
private static Class<?> getCallerClass() {
return System.getSecurityManager() != null ? AccessController.doPrivileged( GET_CALLER_CLASS_ACTION ) :
GET_CALLER_CLASS_ACTION.run();
}
private static class SecurityActions extends SecurityManager {
private Class<?>[] getCallerClass() {
return getClassContext();
}
}
} }