Replace mockito for some tests with custom spies

This commit is contained in:
Christian Beikov 2023-02-13 16:52:04 +01:00
parent edd1c7b7ae
commit 404698b004
20 changed files with 803 additions and 405 deletions

View File

@ -9,28 +9,21 @@ package org.hibernate.test.agroal;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.cfg.Configuration;
import org.hibernate.test.agroal.util.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.DialectChecks;
import org.hibernate.testing.RequiresDialectFeature;
import org.hibernate.testing.junit4.BaseCoreFunctionalTestCase;
import org.junit.Test;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import static org.hibernate.testing.transaction.TransactionUtil.doInHibernate;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
*/
@RequiresDialectFeature(DialectChecks.SupportsJdbcDriverProxying.class)
public class AgroalSkipAutoCommitTest extends BaseCoreFunctionalTestCase {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider();
@ -74,12 +67,15 @@ public class AgroalSkipAutoCommitTest extends BaseCoreFunctionalTestCase {
List<Connection> connections = connectionProvider.getReleasedConnections();
assertEquals( 1, connections.size() );
Connection connection = connections.get( 0 );
try {
verify(connection, never()).setAutoCommit( false );
List<Object[]> setAutoCommitCalls = connectionProvider.spyContext.getCalls(
Connection.class.getMethod( "setAutoCommit", boolean.class ),
connections.get( 0 )
);
assertTrue( "setAutoCommit should never be called", setAutoCommitCalls.isEmpty() );
}
catch (SQLException e) {
fail(e.getMessage());
catch (NoSuchMethodException e) {
throw new RuntimeException( e );
}
}

View File

@ -9,18 +9,13 @@ package org.hibernate.test.agroal.util;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.hibernate.agroal.internal.AgroalConnectionProvider;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil;
import org.hibernate.testing.jdbc.JdbcSpies;
/**
* This {@link ConnectionProvider} extends any other ConnectionProvider that would be used by default taken the current configuration properties, and it
@ -29,8 +24,7 @@ import org.mockito.internal.util.MockUtil;
* @author Vlad Mihalcea
*/
public class PreparedStatementSpyConnectionProvider extends AgroalConnectionProvider {
private final Map<PreparedStatement, String> preparedStatementMap = new LinkedHashMap<>();
public final JdbcSpies.SpyContext spyContext = new JdbcSpies.SpyContext();
private final List<Connection> acquiredConnections = new ArrayList<>( );
private final List<Connection> releasedConnections = new ArrayList<>( );
@ -53,7 +47,7 @@ public class PreparedStatementSpyConnectionProvider extends AgroalConnectionProv
public void closeConnection(Connection conn) throws SQLException {
acquiredConnections.remove( conn );
releasedConnections.add( conn );
super.closeConnection( (Connection) MockUtil.getMockSettings( conn ).getSpiedInstance() );
super.closeConnection( spyContext.getSpiedInstance( conn ) );
}
@Override
@ -63,29 +57,7 @@ public class PreparedStatementSpyConnectionProvider extends AgroalConnectionProv
}
private Connection spy(Connection connection) {
if ( MockUtil.isMock( connection ) ) {
return connection;
}
Connection connectionSpy = Mockito.spy( connection );
try {
Mockito.doAnswer( invocation -> {
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = Mockito.spy( statement );
String sql = (String) invocation.getArguments()[0];
preparedStatementMap.put( statementSpy, sql );
return statementSpy;
} ).when( connectionSpy ).prepareStatement( ArgumentMatchers.anyString() );
Mockito.doAnswer( invocation -> {
Statement statement = (Statement) invocation.callRealMethod();
Statement statementSpy = Mockito.spy( statement );
return statementSpy;
} ).when( connectionSpy ).createStatement();
}
catch ( SQLException e ) {
throw new IllegalArgumentException( e );
}
return connectionSpy;
return JdbcSpies.spy( connection, spyContext );
}
/**
@ -94,8 +66,7 @@ public class PreparedStatementSpyConnectionProvider extends AgroalConnectionProv
public void clear() {
acquiredConnections.clear();
releasedConnections.clear();
preparedStatementMap.keySet().forEach( Mockito::reset );
preparedStatementMap.clear();
spyContext.clear();
}
/**

View File

@ -1,8 +1,8 @@
package org.hibernate.orm.test.insertordering;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
import org.hibernate.cfg.AvailableSettings;
@ -14,18 +14,13 @@ import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.junit.BaseSessionFactoryFunctionalTest;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.junit.jupiter.api.AfterAll;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Nathan Xu
*/
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsJdbcDriverProxying.class)
abstract class BaseInsertOrderingTest extends BaseSessionFactoryFunctionalTest {
static class Batch {
@ -43,8 +38,6 @@ abstract class BaseInsertOrderingTest extends BaseSessionFactoryFunctionalTest {
}
private final PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider(
true,
false
);
@Override
@ -80,10 +73,18 @@ abstract class BaseInsertOrderingTest extends BaseSessionFactoryFunctionalTest {
for ( Batch expectedBatch : expectedBatches ) {
PreparedStatement preparedStatement = connectionProvider.getPreparedStatement( expectedBatch.sql );
try {
verify( preparedStatement, times( expectedBatch.size ) ).addBatch();
verify( preparedStatement, times( 1 ) ).executeBatch();
List<Object[]> addBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "addBatch" ),
preparedStatement
);
List<Object[]> executeBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "executeBatch" ),
preparedStatement
);
assertThat( addBatchCalls.size() ).isEqualTo( expectedBatch.size );
assertThat( executeBatchCalls.size() ).isEqualTo( 1 );
}
catch (SQLException e) {
catch (Exception e) {
throw new RuntimeException( e );
}
}

View File

@ -39,8 +39,6 @@ import static org.junit.jupiter.api.Assertions.fail;
public class AggressiveReleaseTest extends BaseSessionFactoryFunctionalTest {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider(
false,
false,
true
);

View File

@ -7,7 +7,6 @@
package org.hibernate.orm.test.jdbc.internal;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import jakarta.persistence.Entity;
@ -17,24 +16,19 @@ import org.hibernate.Session;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.DialectChecks;
import org.hibernate.testing.RequiresDialectFeature;
import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
*/
@RequiresDialectFeature(DialectChecks.SupportsJdbcDriverProxying.class)
public class SessionJdbcBatchTest
extends BaseNonConfigCoreFunctionalTestCase {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider( true, false );
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider();
@Override
protected Class<?>[] getAnnotatedClasses() {
@ -77,7 +71,7 @@ public class SessionJdbcBatchTest
private long id;
@Test
public void testSessionFactorySetting() throws SQLException {
public void testSessionFactorySetting() throws Throwable {
Session session = sessionFactory().openSession();
session.beginTransaction();
try {
@ -90,13 +84,21 @@ public class SessionJdbcBatchTest
}
PreparedStatement preparedStatement = connectionProvider.getPreparedStatement(
"insert into Event (name,id) values (?,?)" );
verify( preparedStatement, times( 5 ) ).addBatch();
verify( preparedStatement, times( 3 ) ).executeBatch();
List<Object[]> addBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "addBatch" ),
preparedStatement
);
List<Object[]> executeBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "executeBatch" ),
preparedStatement
);
assertEquals( 5, addBatchCalls.size() );
assertEquals( 3, executeBatchCalls.size() );
}
@Test
public void testSessionSettingOverridesSessionFactorySetting()
throws SQLException {
throws Throwable {
Session session = sessionFactory().openSession();
session.setJdbcBatchSize( 3 );
session.beginTransaction();
@ -110,8 +112,16 @@ public class SessionJdbcBatchTest
}
PreparedStatement preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name,id) values (?,?)" );
verify(preparedStatement, times( 5 )).addBatch();
verify(preparedStatement, times( 2 )).executeBatch();
List<Object[]> addBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "addBatch" ),
preparedStatement
);
List<Object[]> executeBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "executeBatch" ),
preparedStatement
);
assertEquals( 5, addBatchCalls.size() );
assertEquals( 2, executeBatchCalls.size() );
session = sessionFactory().openSession();
session.setJdbcBatchSize( null );
@ -127,8 +137,16 @@ public class SessionJdbcBatchTest
List<PreparedStatement> preparedStatements = connectionProvider.getPreparedStatements();
assertEquals(1, preparedStatements.size());
preparedStatement = preparedStatements.get( 0 );
verify(preparedStatement, times( 5 )).addBatch();
verify(preparedStatement, times( 3 )).executeBatch();
addBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "addBatch" ),
preparedStatement
);
executeBatchCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "executeBatch" ),
preparedStatement
);
assertEquals( 5, addBatchCalls.size() );
assertEquals( 3, executeBatchCalls.size() );
}
private void addEvents(Session session) {

View File

@ -32,11 +32,11 @@ import static org.junit.Assert.fail;
/**
* @author Andrea Boriero
*/
@RequiresDialectFeature({DialectChecks.SupportsJdbcDriverProxying.class, DialectChecks.SupportsLockTimeouts.class})
@RequiresDialectFeature({DialectChecks.SupportsLockTimeouts.class})
@SkipForDialect(value = CockroachDialect.class, comment = "for update clause does not imply locking. See https://github.com/cockroachdb/cockroach/issues/88995")
public class StatementIsClosedAfterALockExceptionTest extends BaseEntityManagerFunctionalTestCase {
private static final PreparedStatementSpyConnectionProvider CONNECTION_PROVIDER = new PreparedStatementSpyConnectionProvider( false, false );
private static final PreparedStatementSpyConnectionProvider CONNECTION_PROVIDER = new PreparedStatementSpyConnectionProvider();
private Integer lockId;

View File

@ -27,8 +27,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class StoreProcedureStatementsClosedTest extends BaseSessionFactoryFunctionalTest {
private final PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider(
true,
false
);
@Override

View File

@ -6,8 +6,9 @@
*/
package org.hibernate.orm.test.query;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.List;
import java.util.Map;
import org.hibernate.cfg.AvailableSettings;
@ -34,8 +35,7 @@ import jakarta.persistence.Table;
import static org.hibernate.jpa.SpecHints.HINT_SPEC_QUERY_TIMEOUT;
import static org.hibernate.testing.transaction.TransactionUtil.doInHibernate;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Gail Badner
@ -44,8 +44,6 @@ import static org.mockito.Mockito.verify;
public class QueryTimeOutTest extends BaseNonConfigCoreFunctionalTestCase {
private static final PreparedStatementSpyConnectionProvider CONNECTION_PROVIDER = new PreparedStatementSpyConnectionProvider(
true,
false
);
private static final String QUERY = "update AnEntity set name='abc'";
@ -93,12 +91,15 @@ public class QueryTimeOutTest extends BaseNonConfigCoreFunctionalTestCase {
query.executeUpdate();
try {
verify(
CONNECTION_PROVIDER.getPreparedStatement( expectedSqlQuery ),
times( 1 )
).setQueryTimeout( 123 );
List<Object[]> setQueryTimeoutCalls = CONNECTION_PROVIDER.spyContext.getCalls(
Statement.class.getMethod( "setQueryTimeout", int.class ),
CONNECTION_PROVIDER.getPreparedStatement( expectedSqlQuery )
);
assertEquals( 2, setQueryTimeoutCalls.size() );
assertEquals( 123, setQueryTimeoutCalls.get( 0 )[0] );
assertEquals( 0, setQueryTimeoutCalls.get( 1 )[0] );
}
catch (SQLException ex) {
catch (Exception ex) {
fail( "should not have thrown exception" );
}
}
@ -115,12 +116,15 @@ public class QueryTimeOutTest extends BaseNonConfigCoreFunctionalTestCase {
query.executeUpdate();
try {
verify(
CONNECTION_PROVIDER.getPreparedStatement( expectedSqlQuery ),
times( 1 )
).setQueryTimeout( 123 );
List<Object[]> setQueryTimeoutCalls = CONNECTION_PROVIDER.spyContext.getCalls(
Statement.class.getMethod( "setQueryTimeout", int.class ),
CONNECTION_PROVIDER.getPreparedStatement( expectedSqlQuery )
);
assertEquals( 2, setQueryTimeoutCalls.size() );
assertEquals( 123, setQueryTimeoutCalls.get( 0 )[0] );
assertEquals( 0, setQueryTimeoutCalls.get( 1 )[0] );
}
catch (SQLException ex) {
catch (Exception ex) {
fail( "should not have thrown exception" );
}
}
@ -137,9 +141,15 @@ public class QueryTimeOutTest extends BaseNonConfigCoreFunctionalTestCase {
query.executeUpdate();
try {
verify( CONNECTION_PROVIDER.getPreparedStatement( QUERY ), times( 1 ) ).setQueryTimeout( 123 );
List<Object[]> setQueryTimeoutCalls = CONNECTION_PROVIDER.spyContext.getCalls(
Statement.class.getMethod( "setQueryTimeout", int.class ),
CONNECTION_PROVIDER.getPreparedStatement( QUERY )
);
assertEquals( 2, setQueryTimeoutCalls.size() );
assertEquals( 123, setQueryTimeoutCalls.get( 0 )[0] );
assertEquals( 0, setQueryTimeoutCalls.get( 1 )[0] );
}
catch (SQLException ex) {
catch (Exception ex) {
fail( "should not have thrown exception" );
}
}
@ -156,9 +166,15 @@ public class QueryTimeOutTest extends BaseNonConfigCoreFunctionalTestCase {
query.executeUpdate();
try {
verify( CONNECTION_PROVIDER.getPreparedStatement( QUERY ), times( 1 ) ).setQueryTimeout( 123 );
List<Object[]> setQueryTimeoutCalls = CONNECTION_PROVIDER.spyContext.getCalls(
Statement.class.getMethod( "setQueryTimeout", int.class ),
CONNECTION_PROVIDER.getPreparedStatement( QUERY )
);
assertEquals( 2, setQueryTimeoutCalls.size() );
assertEquals( 123, setQueryTimeoutCalls.get( 0 )[0] );
assertEquals( 0, setQueryTimeoutCalls.get( 1 )[0] );
}
catch (SQLException ex) {
catch (Exception ex) {
fail( "should not have thrown exception" );
}
}
@ -175,9 +191,15 @@ public class QueryTimeOutTest extends BaseNonConfigCoreFunctionalTestCase {
query.executeUpdate();
try {
verify( CONNECTION_PROVIDER.getPreparedStatement( QUERY ), times( 1 ) ).setQueryTimeout( 123 );
List<Object[]> setQueryTimeoutCalls = CONNECTION_PROVIDER.spyContext.getCalls(
Statement.class.getMethod( "setQueryTimeout", int.class ),
CONNECTION_PROVIDER.getPreparedStatement( QUERY )
);
assertEquals( 2, setQueryTimeoutCalls.size() );
assertEquals( 123, setQueryTimeoutCalls.get( 0 )[0] );
assertEquals( 0, setQueryTimeoutCalls.get( 1 )[0] );
}
catch (SQLException ex) {
catch (Exception ex) {
fail( "should not have thrown exception" );
}
}
@ -194,9 +216,15 @@ public class QueryTimeOutTest extends BaseNonConfigCoreFunctionalTestCase {
query.executeUpdate();
try {
verify( CONNECTION_PROVIDER.getPreparedStatement( QUERY ), times( 1 ) ).setQueryTimeout( 123 );
List<Object[]> setQueryTimeoutCalls = CONNECTION_PROVIDER.spyContext.getCalls(
Statement.class.getMethod( "setQueryTimeout", int.class ),
CONNECTION_PROVIDER.getPreparedStatement( QUERY )
);
assertEquals( 2, setQueryTimeoutCalls.size() );
assertEquals( 123, setQueryTimeoutCalls.get( 0 )[0] );
assertEquals( 0, setQueryTimeoutCalls.get( 1 )[0] );
}
catch (SQLException ex) {
catch (Exception ex) {
fail( "should not have thrown exception" );
}
}

View File

@ -13,7 +13,8 @@
*/
package org.hibernate.orm.test.query.criteria;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Map;
import org.hibernate.Session;
@ -24,6 +25,7 @@ import org.hibernate.query.sqm.tree.insert.SqmInsertSelectStatement;
import org.hibernate.query.sqm.tree.select.SqmSelectStatement;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProviderSettingProvider;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.EntityManagerFactoryScope;
import org.hibernate.testing.orm.junit.JiraKey;
@ -31,6 +33,7 @@ import org.hibernate.testing.orm.junit.Jpa;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.Setting;
import org.hibernate.testing.orm.junit.SettingProvider;
import org.junit.Assert;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -44,9 +47,7 @@ import jakarta.persistence.criteria.CriteriaDelete;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.CriteriaUpdate;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
@ -58,7 +59,7 @@ import static org.mockito.Mockito.verify;
settingProviders = {
@SettingProvider(
settingName = AvailableSettings.CONNECTION_PROVIDER,
provider = CriteriaTimeoutTest.SpyConnectionProviderSettingProvider.class)
provider = PreparedStatementSpyConnectionProviderSettingProvider.class)
}
)
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsJdbcDriverProxying.class)
@ -150,13 +151,16 @@ public class CriteriaTimeoutTest {
private void verifyQuerySetTimeoutWasCalled() {
try {
verify(
connectionProvider.getPreparedStatements().get( 0 ),
times( 1 )
).setQueryTimeout( 123 );
List<Object[]> setQueryTimeoutCalls = connectionProvider.spyContext.getCalls(
Statement.class.getMethod( "setQueryTimeout", int.class ),
connectionProvider.getPreparedStatements().get( 0 )
);
assertEquals( 2, setQueryTimeoutCalls.size() );
assertEquals( 123, setQueryTimeoutCalls.get( 0 )[0] );
assertEquals( 0, setQueryTimeoutCalls.get( 1 )[0] );
}
catch (SQLException e) {
fail( "should not have thrown exception" );
catch (Exception ex) {
Assert.fail( "should not have thrown exception" );
}
}
@ -182,11 +186,4 @@ public class CriteriaTimeoutTest {
}
}
public static class SpyConnectionProviderSettingProvider
implements SettingProvider.Provider<PreparedStatementSpyConnectionProvider> {
@Override
public PreparedStatementSpyConnectionProvider getSetting() {
return new PreparedStatementSpyConnectionProvider( true, false );
}
}
}

View File

@ -10,27 +10,23 @@ import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import jakarta.persistence.Entity;
import jakarta.persistence.EntityManagerFactory;
import jakarta.persistence.Id;
import javax.sql.DataSource;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.dialect.Dialect;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.junit.DialectContext;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.junit.EntityManagerFactoryBasedFunctionalTest;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
@ -39,7 +35,7 @@ import static org.mockito.Mockito.verify;
public abstract class AbstractSkipAutoCommitTest extends EntityManagerFactoryBasedFunctionalTest {
private PreparedStatementSpyConnectionProvider connectionProvider =
new PreparedStatementSpyConnectionProvider( false, true ) {
new PreparedStatementSpyConnectionProvider() {
@Override
protected Connection actualConnection() throws SQLException {
Connection connection = super.actualConnection();
@ -84,7 +80,7 @@ public abstract class AbstractSkipAutoCommitTest extends EntityManagerFactoryBas
}
@Test
public void test() {
public void test() throws Throwable {
inTransaction(
entityManager -> {
// Moved inside the transaction because the new base class defers the EMF creation w/ respect to the
@ -113,18 +109,17 @@ public abstract class AbstractSkipAutoCommitTest extends EntityManagerFactoryBas
verifyConnections();
}
private void verifyConnections() {
private void verifyConnections() throws Throwable {
assertTrue( connectionProvider.getAcquiredConnections().isEmpty() );
List<Connection> connections = connectionProvider.getReleasedConnections();
assertEquals( 1, connections.size() );
Connection connection = connections.get( 0 );
try {
verify(connection, never()).setAutoCommit( false );
}
catch (SQLException e) {
fail(e.getMessage());
}
List<Object[]> setAutoCommitCalls = connectionProvider.spyContext.getCalls(
Connection.class.getMethod( "setAutoCommit", boolean.class ),
connection
);
assertEquals( 0, setAutoCommitCalls.size() );
}
@Entity(name = "City" )

View File

@ -8,12 +8,12 @@ package org.hibernate.orm.test.timestamp;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Time;
import java.time.Instant;
import java.time.OffsetTime;
import java.util.Calendar;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.TimeUnit;
@ -24,34 +24,23 @@ import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.junit.BaseSessionFactoryFunctionalTest;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.SkipForDialect;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.mockito.ArgumentCaptor;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
*/
@SkipForDialect(dialectClass = MySQLDialect.class, matchSubTypes = true)
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsJdbcDriverProxying.class)
public class JdbcTimeCustomTimeZoneTest
extends BaseSessionFactoryFunctionalTest {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider(
true,
false
);
private static final TimeZone TIME_ZONE = TimeZone.getTimeZone(
@ -84,7 +73,7 @@ public class JdbcTimeCustomTimeZoneTest
}
@Test
public void testTimeZone() {
public void testTimeZone() throws Throwable {
connectionProvider.clear();
inTransaction( s -> {
@ -97,22 +86,15 @@ public class JdbcTimeCustomTimeZoneTest
assertEquals( 1, connectionProvider.getPreparedStatements().size() );
PreparedStatement ps = connectionProvider.getPreparedStatements()
.get( 0 );
try {
ArgumentCaptor<Calendar> calendarArgumentCaptor = ArgumentCaptor.forClass(
Calendar.class );
verify( ps, times( 1 ) ).setTime(
anyInt(),
any( Time.class ),
calendarArgumentCaptor.capture()
);
assertEquals(
TIME_ZONE,
calendarArgumentCaptor.getValue().getTimeZone()
);
}
catch (SQLException e) {
fail( e.getMessage() );
}
List<Object[]> setTimeCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "setTime", int.class, Time.class, Calendar.class ),
ps
);
assertEquals( 1, setTimeCalls.size() );
assertEquals(
TIME_ZONE,
( (Calendar) setTimeCalls.get( 0 )[2] ).getTimeZone()
);
connectionProvider.clear();
inTransaction( s -> {

View File

@ -7,8 +7,8 @@
package org.hibernate.orm.test.timestamp;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Time;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
@ -26,11 +26,6 @@ import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
@ -40,8 +35,6 @@ public class JdbcTimeDefaultTimeZoneTest
extends BaseSessionFactoryFunctionalTest {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider(
true,
false
);
@Override
@ -67,7 +60,7 @@ public class JdbcTimeDefaultTimeZoneTest
}
@Test
public void testTimeZone() {
public void testTimeZone() throws Throwable {
connectionProvider.clear();
inTransaction( s -> {
@ -80,12 +73,11 @@ public class JdbcTimeDefaultTimeZoneTest
assertEquals( 1, connectionProvider.getPreparedStatements().size() );
PreparedStatement ps = connectionProvider.getPreparedStatements()
.get( 0 );
try {
verify( ps, times( 1 ) ).setTime( anyInt(), any( Time.class ) );
}
catch (SQLException e) {
fail( e.getMessage() );
}
List<Object[]> setTimeCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "setTime", int.class, Time.class ),
ps
);
assertEquals( 1, setTimeCalls.size() );
inTransaction( s -> {
Person person = s.find( Person.class, 1L );

View File

@ -8,10 +8,10 @@ package org.hibernate.orm.test.timestamp;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Timestamp;
import java.util.Calendar;
import java.util.List;
import java.util.TimeZone;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
@ -21,32 +21,23 @@ import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.junit.BaseSessionFactoryFunctionalTest;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.SkipForDialect;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.mockito.ArgumentCaptor;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
*/
@SkipForDialect(dialectClass = MySQLDialect.class, matchSubTypes = true)
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsJdbcDriverProxying.class)
public class JdbcTimestampCustomSessionLevelTimeZoneTest
extends BaseSessionFactoryFunctionalTest {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider( true, false );
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider();
private static final TimeZone TIME_ZONE = TimeZone.getTimeZone(
"America/Los_Angeles" );
@ -73,7 +64,7 @@ public class JdbcTimestampCustomSessionLevelTimeZoneTest
}
@Test
public void testTimeZone() {
public void testTimeZone() throws Throwable {
connectionProvider.clear();
doInHibernateSessionBuilder( () -> {
@ -88,22 +79,15 @@ public class JdbcTimestampCustomSessionLevelTimeZoneTest
assertEquals( 1, connectionProvider.getPreparedStatements().size() );
PreparedStatement ps = connectionProvider.getPreparedStatements()
.get( 0 );
try {
ArgumentCaptor<Calendar> calendarArgumentCaptor = ArgumentCaptor.forClass(
Calendar.class );
verify( ps, times( 1 ) ).setTimestamp(
anyInt(),
any( Timestamp.class ),
calendarArgumentCaptor.capture()
);
assertEquals(
TIME_ZONE,
calendarArgumentCaptor.getValue().getTimeZone()
);
}
catch ( SQLException e ) {
fail( e.getMessage() );
}
List<Object[]> setTimeCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "setTimestamp", int.class, Timestamp.class, Calendar.class ),
ps
);
assertEquals( 1, setTimeCalls.size() );
assertEquals(
TIME_ZONE,
( (Calendar) setTimeCalls.get( 0 )[2] ).getTimeZone()
);
connectionProvider.clear();
doInHibernateSessionBuilder( () -> {

View File

@ -8,10 +8,10 @@ package org.hibernate.orm.test.timestamp;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Timestamp;
import java.util.Calendar;
import java.util.List;
import java.util.TimeZone;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
@ -21,34 +21,23 @@ import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.junit.BaseSessionFactoryFunctionalTest;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.hibernate.testing.orm.junit.SkipForDialect;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.mockito.ArgumentCaptor;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
*/
@SkipForDialect(dialectClass = MySQLDialect.class, matchSubTypes = true)
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsJdbcDriverProxying.class)
public class JdbcTimestampCustomTimeZoneTest
extends BaseSessionFactoryFunctionalTest {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider(
true,
false
);
private static final TimeZone TIME_ZONE = TimeZone.getTimeZone(
@ -81,7 +70,7 @@ public class JdbcTimestampCustomTimeZoneTest
}
@Test
public void testTimeZone() {
public void testTimeZone() throws Throwable {
connectionProvider.clear();
inTransaction( s -> {
@ -94,22 +83,15 @@ public class JdbcTimestampCustomTimeZoneTest
assertEquals( 1, connectionProvider.getPreparedStatements().size() );
PreparedStatement ps = connectionProvider.getPreparedStatements()
.get( 0 );
try {
ArgumentCaptor<Calendar> calendarArgumentCaptor = ArgumentCaptor.forClass(
Calendar.class );
verify( ps, times( 1 ) ).setTimestamp(
anyInt(),
any( Timestamp.class ),
calendarArgumentCaptor.capture()
);
assertEquals(
TIME_ZONE,
calendarArgumentCaptor.getValue().getTimeZone()
);
}
catch (SQLException e) {
fail( e.getMessage() );
}
List<Object[]> setTimeCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "setTimestamp", int.class, Timestamp.class, Calendar.class ),
ps
);
assertEquals( 1, setTimeCalls.size() );
assertEquals(
TIME_ZONE,
( (Calendar) setTimeCalls.get( 0 )[2] ).getTimeZone()
);
connectionProvider.clear();
inTransaction( s -> {

View File

@ -7,8 +7,8 @@
package org.hibernate.orm.test.timestamp;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.List;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
import org.hibernate.cfg.AvailableSettings;
@ -16,8 +16,6 @@ import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.orm.jdbc.PreparedStatementSpyConnectionProvider;
import org.hibernate.testing.orm.junit.BaseSessionFactoryFunctionalTest;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
import org.hibernate.testing.orm.junit.RequiresDialectFeature;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
@ -25,22 +23,14 @@ import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
*/
@RequiresDialectFeature(feature = DialectFeatureChecks.SupportsJdbcDriverProxying.class)
public class JdbcTimestampDefaultTimeZoneTest
extends BaseSessionFactoryFunctionalTest {
private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider(
true,
false
);
@Override
@ -66,7 +56,7 @@ public class JdbcTimestampDefaultTimeZoneTest
}
@Test
public void testTimeZone() {
public void testTimeZone() throws Throwable {
connectionProvider.clear();
inTransaction( s -> {
@ -79,15 +69,11 @@ public class JdbcTimestampDefaultTimeZoneTest
assertEquals( 1, connectionProvider.getPreparedStatements().size() );
PreparedStatement ps = connectionProvider.getPreparedStatements()
.get( 0 );
try {
verify( ps, times( 1 ) ).setTimestamp(
anyInt(),
any( Timestamp.class )
);
}
catch (SQLException e) {
fail( e.getMessage() );
}
List<Object[]> setTimeCalls = connectionProvider.spyContext.getCalls(
PreparedStatement.class.getMethod( "setTimestamp", int.class, Timestamp.class ),
ps
);
assertEquals( 1, setTimeCalls.size() );
inTransaction( s -> {
Person person = s.find( Person.class, 1L );

View File

@ -9,21 +9,14 @@ package org.hibernate.orm.test.util;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.jdbc.ConnectionProviderDelegate;
import org.hibernate.testing.jdbc.JdbcSpies;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil;
/**
* This {@link ConnectionProvider} extends any other ConnectionProvider that would be used by default taken the current configuration properties, and it
@ -33,11 +26,7 @@ import org.mockito.internal.util.MockUtil;
*/
public class PreparedStatementSpyConnectionProvider
extends ConnectionProviderDelegate {
// We must keep around the mocked connections, otherwise the are garbage collected and trigger finalizers
// Since we use CALLS_REAL_METHODS this might close underlying IO resources which make other objects unusable
private static final Queue<Object> MOCKS = new LinkedBlockingQueue<>();
private final Map<PreparedStatement, String> preparedStatementMap = new LinkedHashMap<>();
public final JdbcSpies.SpyContext spyContext = new JdbcSpies.SpyContext();
private final List<Connection> acquiredConnections = new ArrayList<>( );
private final List<Connection> releasedConnections = new ArrayList<>( );
@ -52,7 +41,6 @@ public class PreparedStatementSpyConnectionProvider
@Override
public Connection getConnection() throws SQLException {
Connection connection = spy( actualConnection() );
MOCKS.add( connection );
acquiredConnections.add( connection );
return connection;
}
@ -61,7 +49,7 @@ public class PreparedStatementSpyConnectionProvider
public void closeConnection(Connection conn) throws SQLException {
acquiredConnections.remove( conn );
releasedConnections.add( conn );
super.closeConnection( (Connection) MockUtil.getMockSettings( conn ).getSpiedInstance() );
super.closeConnection( spyContext.getSpiedInstance( conn ) );
}
@Override
@ -71,29 +59,7 @@ public class PreparedStatementSpyConnectionProvider
}
private Connection spy(Connection connection) {
if ( MockUtil.isMock( connection ) ) {
return connection;
}
Connection connectionSpy = Mockito.spy( connection );
try {
Mockito.doAnswer( invocation -> {
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = Mockito.spy( statement );
String sql = (String) invocation.getArguments()[0];
preparedStatementMap.put( statementSpy, sql );
return statementSpy;
} ).when( connectionSpy ).prepareStatement( ArgumentMatchers.anyString() );
Mockito.doAnswer( invocation -> {
Statement statement = (Statement) invocation.callRealMethod();
Statement statementSpy = Mockito.spy( statement );
return statementSpy;
} ).when( connectionSpy ).createStatement();
}
catch ( SQLException e ) {
throw new IllegalArgumentException( e );
}
return connectionSpy;
return JdbcSpies.spy( connection, spyContext );
}
/**
@ -102,8 +68,7 @@ public class PreparedStatementSpyConnectionProvider
public void clear() {
acquiredConnections.clear();
releasedConnections.clear();
preparedStatementMap.keySet().forEach( Mockito::reset );
preparedStatementMap.clear();
spyContext.clear();
}
/**

View File

@ -7,7 +7,6 @@
package org.hibernate.test.hikaricp;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
@ -15,26 +14,21 @@ import jakarta.persistence.Id;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.cfg.Configuration;
import org.hibernate.testing.DialectChecks;
import org.hibernate.testing.RequiresDialectFeature;
import org.hibernate.testing.SkipForDialect;
import org.hibernate.testing.junit4.BaseCoreFunctionalTestCase;
import org.hibernate.dialect.SybaseDialect;
import org.hibernate.orm.test.util.PreparedStatementSpyConnectionProvider;
import org.junit.Test;
import static org.hibernate.testing.transaction.TransactionUtil.doInHibernate;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/**
* @author Vlad Mihalcea
*/
@RequiresDialectFeature(DialectChecks.SupportsJdbcDriverProxying.class)
@SkipForDialect(value = SybaseDialect.class, comment = "The jTDS driver doesn't implement Connection#isValid so this fails")
public class HikariCPSkipAutoCommitTest extends BaseCoreFunctionalTestCase {
@ -91,12 +85,15 @@ public class HikariCPSkipAutoCommitTest extends BaseCoreFunctionalTestCase {
List<Connection> connections = connectionProvider.getReleasedConnections();
assertEquals( 1, connections.size() );
Connection connection = connections.get( 0 );
try {
verify(connection, never()).setAutoCommit( false );
List<Object[]> setAutoCommitCalls = connectionProvider.spyContext.getCalls(
Connection.class.getMethod( "setAutoCommit", boolean.class ),
connections.get( 0 )
);
assertTrue( "setAutoCommit should never be called", setAutoCommitCalls.isEmpty() );
}
catch (SQLException e) {
fail(e.getMessage());
catch (NoSuchMethodException e) {
throw new RuntimeException( e );
}
}

View File

@ -0,0 +1,578 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.testing.jdbc;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.Savepoint;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
/**
* @author Christian Beikov
*/
@SuppressWarnings({"unused"})
public class JdbcSpies {
public interface Callback {
void onCall(Object spy, Method method, Object[] args, Object result);
}
public static class SpyContext {
private final Map<Method, Map<Object, List<Object[]>>> calls = new HashMap<>();
private final List<Callback> callbacks = new ArrayList<>();
private Object call(Object spy, Object realObject, Method method, Object[] args) throws Throwable {
return onCall( spy, method, args, callOnly( realObject, method, args ) );
}
private Object callOnly(Object realObject, Method method, Object[] args) throws Throwable {
try {
return method.invoke( realObject, args );
}
catch (InvocationTargetException e) {
throw e.getTargetException();
}
}
private <T> T onCall(Object spy, Method method, Object[] args, T result) {
calls.computeIfAbsent( method, m -> new IdentityHashMap<>() )
.computeIfAbsent( spy, s -> new ArrayList<>() )
.add( args );
for ( Callback callback : callbacks ) {
callback.onCall( spy, method, args, result );
}
return result;
}
public SpyContext registerCallback(Callback callback) {
callbacks.add( callback );
return this;
}
public List<Object[]> getCalls(Method method, Object spy) {
return calls.getOrDefault( method, Collections.emptyMap() ).getOrDefault( spy, Collections.emptyList() );
}
public void clear() {
calls.clear();
}
public <T> T getSpiedInstance(T spy) {
if ( Proxy.isProxyClass( spy.getClass() ) ) {
final InvocationHandler invocationHandler = Proxy.getInvocationHandler( spy );
if ( invocationHandler instanceof Spy ) {
//noinspection unchecked
return (T) ( (Spy) invocationHandler ).getSpiedInstance();
}
}
throw new IllegalArgumentException( "Passed object is not a spy: " + spy );
}
}
public static Connection spy(Connection connection, SpyContext context) {
return (Connection) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[]{ Connection.class },
new ConnectionHandler( connection, context )
);
}
private interface Spy {
Object getSpiedInstance();
}
private static class ConnectionHandler implements InvocationHandler, Spy {
private final Connection connection;
private final SpyContext context;
private DatabaseMetaData databaseMetaDataProxy;
public ConnectionHandler(Connection connection, SpyContext context) {
this.connection = connection;
this.context = context;
}
@Override
public Object getSpiedInstance() {
return connection;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
switch ( method.getName() ) {
case "getMetaData":
return context.onCall( proxy, method, args, getDatabaseMetaDataProxy( (Connection) proxy ) );
case "createStatement":
return context.onCall(
proxy,
method,
args,
createStatementProxy(
(Statement) context.callOnly( connection, method, args ),
(Connection) proxy
)
);
case "prepareStatement":
return context.onCall(
proxy,
method,
args,
prepareStatementProxy(
(PreparedStatement) context.callOnly( connection, method, args ),
(Connection) proxy
)
);
case "prepareCall":
return context.onCall(
proxy,
method,
args,
prepareCallProxy(
(CallableStatement) context.callOnly( connection, method, args ),
(Connection) proxy
)
);
case "toString":
return context.onCall( proxy, method, args, "Connection proxy [@" + hashCode() + "]" );
case "hashCode":
return context.onCall( proxy, method, args, hashCode() );
case "equals":
return context.onCall( proxy, method, args, proxy == args[0] );
case "setSavepoint":
return savepointProxy( (Savepoint) context.call( proxy, connection, method, args ) );
case "releaseSavepoint":
if ( Proxy.isProxyClass( args[0].getClass() ) ) {
args[0] = ( (SavepointHandler) Proxy.getInvocationHandler( args[0] ) ).savepoint;
}
return context.call( proxy, connection, method, args );
case "rollback":
if ( args != null && args.length != 0 && Proxy.isProxyClass( args[0].getClass() ) ) {
args[0] = ( (SavepointHandler) Proxy.getInvocationHandler( args[0] ) ).savepoint;
}
return context.call( proxy, connection, method, args );
default:
return context.call( proxy, connection, method, args );
}
}
private DatabaseMetaData getDatabaseMetaDataProxy(Connection connectionProxy) throws Throwable {
if ( databaseMetaDataProxy == null ) {
// we need to make it
final DatabaseMetaDataHandler metadataHandler = new DatabaseMetaDataHandler( connection.getMetaData(), connectionProxy, context );
databaseMetaDataProxy = (DatabaseMetaData) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {DatabaseMetaData.class},
metadataHandler
);
}
return databaseMetaDataProxy;
}
private Statement createStatementProxy(Statement statement, Connection connectionProxy) {
return (Statement) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {Statement.class},
new StatementHandler( statement, context, connectionProxy )
);
}
private PreparedStatement prepareStatementProxy(PreparedStatement statement, Connection connectionProxy) {
return (PreparedStatement) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {PreparedStatement.class},
new PreparedStatementHandler( statement, context, connectionProxy )
);
}
private CallableStatement prepareCallProxy(CallableStatement statement, Connection connectionProxy) {
return (CallableStatement) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {CallableStatement.class},
new CallableStatementHandler( statement, context, connectionProxy )
);
}
private Savepoint savepointProxy(Savepoint savepoint) {
return (Savepoint) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {Savepoint.class},
new SavepointHandler( savepoint, context )
);
}
}
private static class StatementHandler implements InvocationHandler, Spy {
protected final Statement statement;
protected final SpyContext context;
protected final Connection connectionProxy;
public StatementHandler(Statement statement, SpyContext context, Connection connectionProxy) {
this.statement = statement;
this.context = context;
this.connectionProxy = connectionProxy;
}
@Override
public Object getSpiedInstance() {
return statement;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
switch ( method.getName() ) {
case "getConnection":
return context.onCall( proxy, method, args, connectionProxy );
case "toString":
return context.onCall( proxy, method, args, "Statement proxy [@" + hashCode() + "]" );
case "hashCode":
return context.onCall( proxy, method, args, hashCode() );
case "equals":
return context.onCall( proxy, method, args, proxy == args[0] );
case "executeQuery":
return context.onCall( proxy, method, args, getResultSetProxy( statement.executeQuery( (String) args[0] ), (Statement) proxy ) );
case "getResultSet":
return context.onCall( proxy, method, args, getResultSetProxy( statement.getResultSet(), (Statement) proxy ) );
case "getGeneratedKeys":
return context.onCall( proxy, method, args, getResultSetProxy( statement.getGeneratedKeys(), (Statement) proxy ) );
default:
return context.call( proxy, statement, method, args );
}
}
protected ResultSet getResultSetProxy(ResultSet resultSet, Statement statementProxy) throws Throwable {
return (ResultSet) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {ResultSet.class},
new ResultSetHandler( resultSet, context, statementProxy )
);
}
}
private static class PreparedStatementHandler extends StatementHandler {
private ResultSetMetaData resultSetMetaDataProxy;
private ParameterMetaData parameterMetaDataProxy;
public PreparedStatementHandler(PreparedStatement statement, SpyContext context, Connection connectionProxy) {
super( statement, context, connectionProxy );
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
switch ( method.getName() ) {
case "toString":
return context.onCall( proxy, method, args, "PreparedStatement proxy [@" + hashCode() + "]" );
case "executeQuery":
return context.onCall(
proxy,
method,
args,
getResultSetProxy(
(ResultSet) context.callOnly( statement, method, args ),
(PreparedStatement) proxy
)
);
case "getMetaData":
return context.onCall(
proxy,
method,
args,
getResultSetMetaDataProxy( ( (PreparedStatement) statement ).getMetaData() )
);
case "getParameterMetaData":
return context.onCall(
proxy,
method,
args,
getParameterMetaDataProxy( ( (PreparedStatement) statement ).getParameterMetaData() )
);
default:
return super.invoke( proxy, method, args );
}
}
private ResultSetMetaData getResultSetMetaDataProxy(ResultSetMetaData metaData) throws Throwable {
if ( resultSetMetaDataProxy == null ) {
// we need to make it
resultSetMetaDataProxy = (ResultSetMetaData) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {ResultSetMetaData.class},
new ResultSetMetaDataHandler( metaData, context )
);
}
return resultSetMetaDataProxy;
}
private ParameterMetaData getParameterMetaDataProxy(ParameterMetaData metaData) throws Throwable {
if ( parameterMetaDataProxy == null ) {
// we need to make it
parameterMetaDataProxy = (ParameterMetaData) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {ParameterMetaData.class},
new ParameterMetaDataHandler( metaData, context )
);
}
return parameterMetaDataProxy;
}
}
private static class CallableStatementHandler extends PreparedStatementHandler {
public CallableStatementHandler(CallableStatement statement, SpyContext context, Connection connectionProxy) {
super( statement, context, connectionProxy );
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if ( "toString".equals( method.getName() ) ) {
return context.onCall( proxy, method, args, "CallableStatement proxy [@" + hashCode() + "]" );
}
else {
return super.invoke( proxy, method, args );
}
}
}
private static class DatabaseMetaDataHandler implements InvocationHandler, Spy {
private final DatabaseMetaData databaseMetaData;
private final Connection connectionProxy;
private final SpyContext context;
public DatabaseMetaDataHandler(
DatabaseMetaData databaseMetaData,
Connection connectionProxy,
SpyContext context) {
this.databaseMetaData = databaseMetaData;
this.connectionProxy = connectionProxy;
this.context = context;
}
@Override
public Object getSpiedInstance() {
return databaseMetaData;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
switch ( method.getName() ) {
case "getConnection":
return context.onCall( proxy, method, args, connectionProxy );
case "toString":
return context.onCall( proxy, method, args, "DatabaseMetaData proxy [@" + hashCode() + "]" );
case "hashCode":
return context.onCall( proxy, method, args, hashCode() );
case "equals":
return context.onCall( proxy, method, args, proxy == args[0] );
case "getProcedures":
case "getProcedureColumns":
case "getTables":
case "getSchemas":
case "getCatalogs":
case "getTableTypes":
case "getColumns":
case "getColumnPrivileges":
case "getTablePrivileges":
case "getBestRowIdentifier":
case "getVersionColumns":
case "getPrimaryKeys":
case "getImportedKeys":
case "getExportedKeys":
case "getCrossReference":
case "getTypeInfo":
case "getIndexInfo":
case "getUDTs":
case "getSuperTypes":
case "getSuperTables":
case "getAttributes":
case "getClientInfoProperties":
case "getFunctions":
case "getFunctionColumns":
case "getPseudoColumns":
final ResultSet resultSet = (ResultSet) context.callOnly( databaseMetaData, method, args );
return context.onCall( proxy, method, args, getResultSetProxy( resultSet, getStatementProxy( resultSet.getStatement() ) ) );
default:
return context.call( proxy, databaseMetaData, method, args );
}
}
protected ResultSet getResultSetProxy(ResultSet resultSet, Statement statementProxy) throws Throwable {
return (ResultSet) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {ResultSet.class},
new ResultSetHandler( resultSet, context, statementProxy )
);
}
protected Statement getStatementProxy(Statement statement) throws Throwable {
final InvocationHandler handler;
if ( statement instanceof CallableStatement ) {
handler = new CallableStatementHandler( (CallableStatement) statement, context, connectionProxy );
}
else if ( statement instanceof PreparedStatement ) {
handler = new PreparedStatementHandler( (PreparedStatement) statement, context, connectionProxy );
}
else {
handler = new StatementHandler( statement, context, connectionProxy );
}
return (Statement) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {Statement.class},
handler
);
}
}
private static class ParameterMetaDataHandler implements InvocationHandler, Spy {
private final ParameterMetaData parameterMetaData;
private final SpyContext context;
public ParameterMetaDataHandler(ParameterMetaData parameterMetaData, SpyContext context) {
this.parameterMetaData = parameterMetaData;
this.context = context;
}
@Override
public Object getSpiedInstance() {
return parameterMetaData;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
switch ( method.getName() ) {
case "toString":
return context.onCall( proxy, method, args, "DatabaseMetaData proxy [@" + hashCode() + "]" );
case "hashCode":
return context.onCall( proxy, method, args, hashCode() );
case "equals":
return context.onCall( proxy, method, args, proxy == args[0] );
default:
return context.call( proxy, parameterMetaData, method, args );
}
}
}
private static class ResultSetHandler implements InvocationHandler, Spy {
private final ResultSet resultSet;
private final SpyContext context;
private final Statement statementProxy;
private ResultSetMetaData metadataProxy;
public ResultSetHandler(ResultSet resultSet, SpyContext context, Statement statementProxy) {
this.resultSet = resultSet;
this.context = context;
this.statementProxy = statementProxy;
}
@Override
public Object getSpiedInstance() {
return resultSet;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
final String methodName = method.getName();
switch ( methodName ) {
case "getMetaData":
return context.onCall( proxy, method, args, getResultSetMetaDataProxy( resultSet.getMetaData() ) );
case "getStatement":
return context.onCall( proxy, method, args, statementProxy );
case "toString":
return context.onCall( proxy, method, args, "ResultSet proxy [@" + hashCode() + "]" );
case "hashCode":
return context.onCall( proxy, method, args, hashCode() );
case "equals":
return context.onCall( proxy, method, args, proxy == args[0] );
default:
return context.call( proxy, resultSet, method, args );
}
}
private ResultSetMetaData getResultSetMetaDataProxy(ResultSetMetaData metaData) throws Throwable {
if ( metadataProxy == null ) {
// we need to make it
final ResultSetMetaDataHandler metadataHandler = new ResultSetMetaDataHandler( metaData, context );
metadataProxy = (ResultSetMetaData) Proxy.newProxyInstance(
ClassLoader.getSystemClassLoader(),
new Class[] {ResultSetMetaData.class},
metadataHandler
);
}
return metadataProxy;
}
}
private static class ResultSetMetaDataHandler implements InvocationHandler, Spy {
private final ResultSetMetaData resultSetMetaData;
private final SpyContext context;
public ResultSetMetaDataHandler(ResultSetMetaData resultSetMetaData, SpyContext context) {
this.resultSetMetaData = resultSetMetaData;
this.context = context;
}
@Override
public Object getSpiedInstance() {
return resultSetMetaData;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
switch ( method.getName() ) {
case "toString":
return context.onCall( proxy, method, args, "ResultSetMetaData proxy [@" + hashCode() + "]" );
case "hashCode":
return context.onCall( proxy, method, args, hashCode() );
case "equals":
return context.onCall( proxy, method, args, proxy == args[0] );
default:
return context.call( proxy, resultSetMetaData, method, args );
}
}
}
private static class SavepointHandler implements InvocationHandler, Spy {
private final Savepoint savepoint;
private final SpyContext context;
public SavepointHandler(Savepoint savepoint, SpyContext context) {
this.savepoint = savepoint;
this.context = context;
}
@Override
public Object getSpiedInstance() {
return savepoint;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
switch ( method.getName() ) {
case "toString":
return context.onCall( proxy, method, args, "Savepoint proxy [@" + hashCode() + "]" );
case "hashCode":
return context.onCall( proxy, method, args, hashCode() );
case "equals":
return context.onCall( proxy, method, args, proxy == args[0] );
default:
return context.call( proxy, savepoint, method, args );
}
}
}
}

View File

@ -14,18 +14,12 @@ import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.stream.Collectors;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.testing.jdbc.ConnectionProviderDelegate;
import org.mockito.ArgumentMatchers;
import org.mockito.MockSettings;
import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil;
import org.hibernate.testing.jdbc.JdbcSpies;
/**
* This {@link ConnectionProvider} extends any other ConnectionProvider that would be used by default taken the current configuration properties, and it
@ -36,57 +30,38 @@ import org.mockito.internal.util.MockUtil;
*/
public class PreparedStatementSpyConnectionProvider extends ConnectionProviderDelegate {
// We must keep around the mocked connections, otherwise they are garbage collected and trigger finalizers
// Since we use CALLS_REAL_METHODS this might close underlying IO resources which makes other objects unusable
private static final Queue<Object> MOCKS = new LinkedBlockingQueue<>();
private final Map<PreparedStatement, String> preparedStatementMap = new LinkedHashMap<>();
private final List<String> executeStatements = new ArrayList<>( 4 );
private final List<String> executeUpdateStatements = new ArrayList<>( 4 );
public final JdbcSpies.SpyContext spyContext = new JdbcSpies.SpyContext()
.registerCallback(
(spy, method, args, result) -> {
if ( method.getDeclaringClass() == Connection.class
&& method.getName().equals( "prepareStatement" ) ) {
preparedStatementMap.put( (PreparedStatement) result, (String) args[0] );
}
else if ( method.getDeclaringClass() == Statement.class
&& method.getName().equals( "execute" ) ) {
executeStatements.add( (String) args[0] );
}
else if ( method.getDeclaringClass() == Statement.class
&& method.getName().equals( "executeUpdate" ) ) {
executeUpdateStatements.add( (String) args[0] );
}
}
);
private final List<Connection> acquiredConnections = new ArrayList<>( 4 );
private final List<Connection> releasedConnections = new ArrayList<>( 4 );
private final MockSettings settingsForStatements;
private final MockSettings settingsForConnections;
/**
* @deprecated best use the {@link #PreparedStatementSpyConnectionProvider(boolean,boolean)} method to be explicit about the limitations.
*/
@Deprecated
public PreparedStatementSpyConnectionProvider() {
this( false, false, false );
this( false );
}
/**
* Careful: the default is to use mocks which do not allow to verify invocations, as otherwise the
* memory usage of the testsuite is extremely high.
* When you really need to verify invocations, set the relevant constructor parameter to true.
*/
public PreparedStatementSpyConnectionProvider(
boolean allowMockVerificationOnStatements,
boolean allowMockVerificationOnConnections) {
this( allowMockVerificationOnStatements, allowMockVerificationOnConnections, false );
}
public PreparedStatementSpyConnectionProvider(boolean allowMockVerificationOnStatements, boolean allowMockVerificationOnConnections, boolean forceSupportsAggressiveRelease) {
public PreparedStatementSpyConnectionProvider(boolean forceSupportsAggressiveRelease) {
super(forceSupportsAggressiveRelease);
this.settingsForStatements = allowMockVerificationOnStatements ?
getVerifiableMockSettings() :
getMockSettings();
this.settingsForConnections = allowMockVerificationOnConnections ?
getVerifiableMockSettings() :
getMockSettings();
}
private static MockSettings getMockSettings() {
return Mockito.withSettings()
.stubOnly() //important optimisation: uses far less memory, at tradeoff of mocked methods no longer being verifiable but we often don't need that.
.defaultAnswer( org.mockito.Answers.CALLS_REAL_METHODS );
}
private static MockSettings getVerifiableMockSettings() {
return Mockito.withSettings().defaultAnswer( org.mockito.Answers.CALLS_REAL_METHODS );
}
protected Connection actualConnection() throws SQLException {
@ -96,7 +71,6 @@ public class PreparedStatementSpyConnectionProvider extends ConnectionProviderDe
@Override
public Connection getConnection() throws SQLException {
Connection connection = instrumentConnection( actualConnection() );
MOCKS.add( connection );
acquiredConnections.add( connection );
return connection;
}
@ -105,7 +79,7 @@ public class PreparedStatementSpyConnectionProvider extends ConnectionProviderDe
public void closeConnection(Connection conn) throws SQLException {
acquiredConnections.remove( conn );
releasedConnections.add( conn );
super.closeConnection( (Connection) MockUtil.getMockSettings( conn ).getSpiedInstance() );
super.closeConnection( spyContext.getSpiedInstance( conn ) );
}
@Override
@ -115,51 +89,7 @@ public class PreparedStatementSpyConnectionProvider extends ConnectionProviderDe
}
private Connection instrumentConnection(Connection connection) {
if ( MockUtil.isMock( connection ) ) {
return connection;
}
Connection connectionSpy = spy( connection, settingsForConnections );
try {
Mockito.doAnswer( invocation -> {
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = spy( statement, settingsForStatements );
String sql = (String) invocation.getArguments()[0];
preparedStatementMap.put( statementSpy, sql );
return statementSpy;
} ).when( connectionSpy ).prepareStatement( ArgumentMatchers.anyString() );
Mockito.doAnswer( invocation -> {
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = spy( statement, settingsForStatements );
String sql = (String) invocation.getArguments()[0];
preparedStatementMap.put( statementSpy, sql );
return statementSpy;
} ).when( connectionSpy ).prepareCall( ArgumentMatchers.anyString() );
Mockito.doAnswer( invocation -> {
Statement statement = (Statement) invocation.callRealMethod();
Statement statementSpy = spy( statement, settingsForStatements );
Mockito.doAnswer( statementInvocation -> {
String sql = (String) statementInvocation.getArguments()[0];
executeStatements.add( sql );
return statementInvocation.callRealMethod();
} ).when( statementSpy ).execute( ArgumentMatchers.anyString() );
Mockito.doAnswer( statementInvocation -> {
String sql = (String) statementInvocation.getArguments()[0];
executeUpdateStatements.add( sql );
return statementInvocation.callRealMethod();
} ).when( statementSpy ).executeUpdate( ArgumentMatchers.anyString() );
return statementSpy;
} ).when( connectionSpy ).createStatement();
}
catch ( SQLException e ) {
throw new IllegalArgumentException( e );
}
return connectionSpy;
}
private static <T> T spy(T subject, MockSettings mockSettings) {
return Mockito.mock( (Class<T>) subject.getClass(), mockSettings.spiedInstance( subject ) );
return JdbcSpies.spy( connection, spyContext );
}
/**
@ -168,7 +98,7 @@ public class PreparedStatementSpyConnectionProvider extends ConnectionProviderDe
public void clear() {
acquiredConnections.clear();
releasedConnections.clear();
preparedStatementMap.keySet().forEach( Mockito::reset );
spyContext.clear();
preparedStatementMap.clear();
executeStatements.clear();
executeUpdateStatements.clear();

View File

@ -14,6 +14,6 @@ import org.hibernate.testing.orm.junit.SettingProvider;
public class PreparedStatementSpyConnectionProviderSettingProvider implements SettingProvider.Provider<PreparedStatementSpyConnectionProvider> {
@Override
public PreparedStatementSpyConnectionProvider getSetting() {
return new PreparedStatementSpyConnectionProvider( false, false );
return new PreparedStatementSpyConnectionProvider();
}
}