diff --git a/hibernate-core/src/main/java/org/hibernate/bytecode/internal/bytebuddy/HibernateMethodLookupDispatcher.java b/hibernate-core/src/main/java/org/hibernate/bytecode/internal/bytebuddy/HibernateMethodLookupDispatcher.java index bb3b287f6a..50132a24c1 100644 --- a/hibernate-core/src/main/java/org/hibernate/bytecode/internal/bytebuddy/HibernateMethodLookupDispatcher.java +++ b/hibernate-core/src/main/java/org/hibernate/bytecode/internal/bytebuddy/HibernateMethodLookupDispatcher.java @@ -10,7 +10,6 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.security.AccessController; import java.security.PrivilegedAction; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; @@ -23,9 +22,12 @@ public class HibernateMethodLookupDispatcher { private static final SecurityActions SECURITY_ACTIONS = new SecurityActions(); private static final Function> STACK_FRAME_GET_DECLARING_CLASS_FUNCTION; + @SuppressWarnings("rawtypes") + private static final Function STACK_FRAME_EXTRACT_FUNCTION; private static Object stackWalker; private static Method stackWalkerWalkMethod; private static Method stackFrameGetDeclaringClass; + private static final PrivilegedAction> GET_CALLER_CLASS_ACTION; // Currently, the bytecode provider is created statically and shared between all the session factories. Thus we // can't clear this set when we close a session factory as we might remove elements coming from another one. @@ -133,49 +135,67 @@ public class HibernateMethodLookupDispatcher { } } }; - } - private static Class getCallerClass() { - PrivilegedAction> getCallerClassAction = new PrivilegedAction>() { + STACK_FRAME_EXTRACT_FUNCTION = new Function() { @Override @SuppressWarnings({ "unchecked", "rawtypes" }) + public Object apply(Stream stream) { + return stream.map( STACK_FRAME_GET_DECLARING_CLASS_FUNCTION ) + .limit( 16 ) + .toArray( Class[]::new ); + } + }; + + GET_CALLER_CLASS_ACTION = new PrivilegedAction>() { + + @Override public Class run() { try { + Class[] stackTrace; if ( stackWalker != null ) { - Optional> clazzOptional = (Optional>) stackWalkerWalkMethod.invoke( stackWalker, new Function() { - @Override - public Object apply(Stream stream) { - return stream.map( STACK_FRAME_GET_DECLARING_CLASS_FUNCTION ) - .skip( System.getSecurityManager() != null ? 6 : 5 ) - .findFirst(); - } - }); - - if ( !clazzOptional.isPresent() ) { - throw new HibernateException( "Unable to determine the caller class" ); - } - - return clazzOptional.get(); + stackTrace = (Class[]) stackWalkerWalkMethod.invoke( stackWalker, STACK_FRAME_EXTRACT_FUNCTION ); } else { - return SECURITY_ACTIONS.getCallerClass(); + stackTrace = SECURITY_ACTIONS.getCallerClass(); } + + // this shouldn't happen but let's be safe + if ( stackTrace.length < 4 ) { + throw new SecurityException( "Unable to determine the caller class" ); + } + + boolean hibernateMethodLookupDispatcherDetected = false; + // start at the 4th frame and limit that to the 16 first frames + int maxFrames = Math.min( 16, stackTrace.length ); + for ( int i = 3; i < maxFrames; i++ ) { + if ( stackTrace[i].getName().equals( HibernateMethodLookupDispatcher.class.getName() ) ) { + hibernateMethodLookupDispatcherDetected = true; + continue; + } + if ( hibernateMethodLookupDispatcherDetected ) { + return stackTrace[i]; + } + } + + throw new SecurityException( "Unable to determine the caller class" ); } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { throw new SecurityException( "Unable to determine the caller class", e ); } } }; + } - return System.getSecurityManager() != null ? AccessController.doPrivileged( getCallerClassAction ) : - getCallerClassAction.run(); + private static Class getCallerClass() { + return System.getSecurityManager() != null ? AccessController.doPrivileged( GET_CALLER_CLASS_ACTION ) : + GET_CALLER_CLASS_ACTION.run(); } private static class SecurityActions extends SecurityManager { - private Class getCallerClass() { - return getClassContext()[7]; + private Class[] getCallerClass() { + return getClassContext(); } } }