diff --git a/hibernate-core/src/main/java/org/hibernate/context/internal/ThreadLocalSessionContext.java b/hibernate-core/src/main/java/org/hibernate/context/internal/ThreadLocalSessionContext.java index e6cda3fe7c..c9387a14df 100644 --- a/hibernate-core/src/main/java/org/hibernate/context/internal/ThreadLocalSessionContext.java +++ b/hibernate-core/src/main/java/org/hibernate/context/internal/ThreadLocalSessionContext.java @@ -15,6 +15,7 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.util.HashMap; +import java.util.Locale; import java.util.Map; import javax.transaction.Synchronization; @@ -291,17 +292,33 @@ public class ThreadLocalSessionContext extends AbstractCurrentSessionContext { } @Override + @SuppressWarnings("SimplifiableIfStatement") public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { - final String methodName = method.getName(); + final String methodName = method.getName(); + + // first check methods calls that we handle completely locally: + if ( "equals".equals( methodName ) && method.getParameterCount() == 1 ) { + if ( args[0] == null + || !Proxy.isProxyClass( args[0].getClass() ) ) { + return false; + } + return this.equals( Proxy.getInvocationHandler( args[0] ) ); + } + else if ( "hashCode".equals( methodName ) && method.getParameterCount() == 0 ) { + return this.hashCode(); + } + else if ( "toString".equals( methodName ) && method.getParameterCount() == 0 ) { + return String.format( Locale.ROOT, "ThreadLocalSessionContext.TransactionProtectionWrapper[%s]", realSession ); + } + + + // then check method calls that we need to delegate to the real Session try { // If close() is called, guarantee unbind() if ( "close".equals( methodName ) ) { unbind( realSession.getSessionFactory() ); } - else if ( "toString".equals( methodName ) - || "equals".equals( methodName ) - || "hashCode".equals( methodName ) - || "getStatistics".equals( methodName ) + else if ( "getStatistics".equals( methodName ) || "isOpen".equals( methodName ) || "getListeners".equals( methodName ) ) { // allow these to go through the the real session no matter what diff --git a/hibernate-core/src/test/java/org/hibernate/test/connections/ThreadLocalCurrentSessionTest.java b/hibernate-core/src/test/java/org/hibernate/test/connections/ThreadLocalCurrentSessionTest.java index 8f939f560f..ac04d3a1cb 100644 --- a/hibernate-core/src/test/java/org/hibernate/test/connections/ThreadLocalCurrentSessionTest.java +++ b/hibernate-core/src/test/java/org/hibernate/test/connections/ThreadLocalCurrentSessionTest.java @@ -17,10 +17,12 @@ import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.resource.transaction.spi.TransactionStatus; import org.hibernate.testing.RequiresDialect; +import org.hibernate.testing.TestForIssue; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -73,12 +75,22 @@ public class ThreadLocalCurrentSessionTest extends ConnectionManagementTestCase assertTrue( "session not bound after deserialize", TestableThreadLocalContext.isSessionBound( session ) ); } + @Test + @TestForIssue(jiraKey = "HHH-11067") + public void testEqualityChecking() { + Session session1 = sessionFactory().getCurrentSession(); + Session session2 = sessionFactory().getCurrentSession(); + + assertSame( "== check", session1, session2 ); + assertEquals( "#equals check", session1, session2 ); + } + @Test public void testTransactionProtection() { Session session = sessionFactory().getCurrentSession(); try { session.createQuery( "from Silly" ); - fail( "method other than beginTransaction{} allowed" ); + fail( "method other than beginTransaction() allowed" ); } catch ( HibernateException e ) { // ok