diff --git a/hibernate-core/src/main/java/org/hibernate/SharedSessionContract.java b/hibernate-core/src/main/java/org/hibernate/SharedSessionContract.java index 5bef15102e..e92d7bdb87 100644 --- a/hibernate-core/src/main/java/org/hibernate/SharedSessionContract.java +++ b/hibernate-core/src/main/java/org/hibernate/SharedSessionContract.java @@ -183,5 +183,5 @@ public interface SharedSessionContract extends QueryProducer, Serializable { * @see org.hibernate.boot.spi.SessionFactoryOptions#getJdbcBatchSize * @see org.hibernate.boot.SessionFactoryBuilder#applyJdbcBatchSize */ - void setJdbcBatchSize(int jdbcBatchSize); + void setJdbcBatchSize(Integer jdbcBatchSize); } diff --git a/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java b/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java index 3789876a3b..b2bcf71f25 100644 --- a/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java @@ -1163,7 +1163,7 @@ public class SessionDelegatorBaseImpl implements SessionImplementor { } @Override - public void setJdbcBatchSize(int jdbcBatchSize) { + public void setJdbcBatchSize(Integer jdbcBatchSize) { delegate.setJdbcBatchSize( jdbcBatchSize ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java b/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java index a085223400..67352020a0 100644 --- a/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java +++ b/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java @@ -964,7 +964,7 @@ public abstract class AbstractSharedSessionContract implements SharedSessionCont } @Override - public void setJdbcBatchSize(int jdbcBatchSize) { + public void setJdbcBatchSize(Integer jdbcBatchSize) { this.jdbcBatchSize = jdbcBatchSize; } diff --git a/hibernate-core/src/test/java/org/hibernate/test/jdbc/internal/SessionJdbcBatchTest.java b/hibernate-core/src/test/java/org/hibernate/test/jdbc/internal/SessionJdbcBatchTest.java new file mode 100644 index 0000000000..f48c3b9ace --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/test/jdbc/internal/SessionJdbcBatchTest.java @@ -0,0 +1,143 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.test.jdbc.internal; + +import java.util.Map; +import javax.persistence.Entity; +import javax.persistence.Id; + +import org.hibernate.Session; +import org.hibernate.cfg.AvailableSettings; + +import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase; +import org.hibernate.test.util.JdbcStatisticsConnectionProvider; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author Vlad Mihalcea + */ +public class SessionJdbcBatchTest + extends BaseNonConfigCoreFunctionalTestCase { + + private JdbcStatisticsConnectionProvider connectionProvider = new JdbcStatisticsConnectionProvider(); + + @Override + protected Class[] getAnnotatedClasses() { + return new Class[] { Event.class }; + } + + @Override + protected void addSettings(Map settings) { + settings.put( AvailableSettings.STATEMENT_BATCH_SIZE, 2 ); + settings.put( + AvailableSettings.CONNECTION_PROVIDER, + connectionProvider + ); + } + + @Override + protected boolean rebuildSessionFactoryOnError() { + return false; + } + + @Override + protected boolean isCleanupTestDataRequired() { + return true; + } + + private long id; + + @Test + public void testSessionFactorySetting() { + Session session = sessionFactory().openSession(); + session.beginTransaction(); + try { + addEvents( session ); + } + finally { + connectionProvider.getPreparedStatementStatistics().clear(); + session.getTransaction().commit(); + session.close(); + } + JdbcStatisticsConnectionProvider.PreparedStatementStatistics statementStatistics = + connectionProvider.getPreparedStatementStatistics().get( + "insert into Event (name, id) values (?, ?)" ); + assertEquals( 5, statementStatistics.getAddBatchCount() ); + assertEquals( 3, statementStatistics.getExecuteBatchCount() ); + } + + @Test + public void testSessionSettingOverridesSessionFactorySetting() { + Session session = sessionFactory().openSession(); + session.setJdbcBatchSize( 3 ); + session.beginTransaction(); + try { + addEvents( session ); + } + finally { + connectionProvider.clear(); + session.getTransaction().commit(); + session.close(); + } + JdbcStatisticsConnectionProvider.PreparedStatementStatistics statementStatistics = + connectionProvider.getPreparedStatementStatistics().get( + "insert into Event (name, id) values (?, ?)" ); + assertEquals( 5, statementStatistics.getAddBatchCount() ); + assertEquals( 2, statementStatistics.getExecuteBatchCount() ); + + session = sessionFactory().openSession(); + session.setJdbcBatchSize( null ); + session.beginTransaction(); + try { + addEvents( session ); + } + finally { + connectionProvider.clear(); + session.getTransaction().commit(); + session.close(); + } + statementStatistics = + connectionProvider.getPreparedStatementStatistics().get( + "insert into Event (name, id) values (?, ?)" ); + assertEquals( 5, statementStatistics.getAddBatchCount() ); + assertEquals( 3, statementStatistics.getExecuteBatchCount() ); + } + + private void addEvents(Session session) { + for ( long i = 0; i < 5; i++ ) { + Event event = new Event(); + event.id = id++; + event.name = "Event " + i; + session.persist( event ); + } + } + + @Test + public void testSessionJdbcBatchOverridesSessionFactorySetting() { + + Session session = sessionFactory().openSession(); + session.beginTransaction(); + try { + + } + finally { + session.getTransaction().commit(); + session.close(); + } + } + + @Entity(name = "Event") + public static class Event { + + @Id + private Long id; + + private String name; + } +} diff --git a/hibernate-core/src/test/java/org/hibernate/test/util/JdbcStatisticsConnectionProvider.java b/hibernate-core/src/test/java/org/hibernate/test/util/JdbcStatisticsConnectionProvider.java new file mode 100644 index 0000000000..49e81507f5 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/test/util/JdbcStatisticsConnectionProvider.java @@ -0,0 +1,115 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.test.util; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +import org.hibernate.engine.jdbc.connections.internal.DriverManagerConnectionProviderImpl; + +import org.mockito.Mockito; +import org.mockito.internal.util.MockUtil; + +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doAnswer; + +/** + * @author Vlad Mihalcea + */ +public class JdbcStatisticsConnectionProvider extends + DriverManagerConnectionProviderImpl { + + public static class PreparedStatementStatistics { + private final String sql; + + private int executeUpdateCount; + + private int addBatchCount; + + private int executeBatchCount; + + public PreparedStatementStatistics(String sql) { + this.sql = sql; + } + + public int getExecuteUpdateCount() { + return executeUpdateCount; + } + + public int getAddBatchCount() { + return addBatchCount; + } + + public int getExecuteBatchCount() { + return executeBatchCount; + } + } + + private final Map preparedStatementStatisticsMap = new HashMap<>(); + + @Override + public Connection getConnection() throws SQLException { + return spy( super.getConnection() ); + } + + private Connection spy(Connection connection) { + if ( new MockUtil().isMock( connection ) ) { + return connection; + } + Connection connectionSpy = Mockito.spy( connection ); + try { + doAnswer( invocation -> { + PreparedStatement statement = (PreparedStatement) invocation.callRealMethod(); + PreparedStatement statementSpy = Mockito.spy( statement ); + preparedStatementStatisticsMap.putIfAbsent( statementSpy, + new PreparedStatementStatistics( + (String) invocation + .getArguments()[0] ) + ); + + doAnswer( _invocation -> { + Object mock = _invocation.getMock(); + preparedStatementStatisticsMap.get( mock ).executeUpdateCount++; + return _invocation.callRealMethod(); + } ).when( statementSpy ).executeUpdate(); + + doAnswer( _invocation -> { + Object mock = _invocation.getMock(); + preparedStatementStatisticsMap.get( mock ).addBatchCount++; + return _invocation.callRealMethod(); + } ).when( statementSpy ).addBatch(); + + doAnswer( _invocation -> { + Object mock = _invocation.getMock(); + preparedStatementStatisticsMap.get( mock ).executeBatchCount++; + return _invocation.callRealMethod(); + } ).when( statementSpy ).executeBatch(); + + return statementSpy; + } ).when( connectionSpy ).prepareStatement( anyString() ); + } + catch ( SQLException e ) { + e.printStackTrace(); + } + return connectionSpy; + } + + public void clear() { + preparedStatementStatisticsMap.clear(); + } + + public Map getPreparedStatementStatistics() { + Map statisticsMap = new HashMap<>(); + preparedStatementStatisticsMap.values() + .stream() + .forEach( stats -> statisticsMap.put( stats.sql, stats ) ); + return statisticsMap; + } +}