Improve the PreparedStatement assertion mechanism to rely on Mockito solely

This commit is contained in:
Vlad Mihalcea 2016-06-14 12:02:14 +03:00
parent be93105e9a
commit 6142f92d2f
5 changed files with 86 additions and 152 deletions

View File

@ -6,6 +6,8 @@
*/ */
package org.hibernate.test.jdbc.internal; package org.hibernate.test.jdbc.internal;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Map; import java.util.Map;
import javax.persistence.Entity; import javax.persistence.Entity;
import javax.persistence.Id; import javax.persistence.Id;
@ -14,10 +16,11 @@ import org.hibernate.Session;
import org.hibernate.cfg.AvailableSettings; import org.hibernate.cfg.AvailableSettings;
import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase; import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase;
import org.hibernate.test.util.JdbcStatisticsConnectionProvider; import org.hibernate.test.util.jdbc.PreparedStatementSpyConnectionProvider;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/** /**
* @author Vlad Mihalcea * @author Vlad Mihalcea
@ -25,7 +28,7 @@ import static org.junit.Assert.assertEquals;
public class SessionJdbcBatchTest public class SessionJdbcBatchTest
extends BaseNonConfigCoreFunctionalTestCase { extends BaseNonConfigCoreFunctionalTestCase {
private JdbcStatisticsConnectionProvider connectionProvider = new JdbcStatisticsConnectionProvider(); private PreparedStatementSpyConnectionProvider connectionProvider = new PreparedStatementSpyConnectionProvider();
@Override @Override
protected Class<?>[] getAnnotatedClasses() { protected Class<?>[] getAnnotatedClasses() {
@ -54,26 +57,25 @@ public class SessionJdbcBatchTest
private long id; private long id;
@Test @Test
public void testSessionFactorySetting() { public void testSessionFactorySetting() throws SQLException {
Session session = sessionFactory().openSession(); Session session = sessionFactory().openSession();
session.beginTransaction(); session.beginTransaction();
try { try {
addEvents( session ); addEvents( session );
} }
finally { finally {
connectionProvider.getPreparedStatementStatistics().clear(); connectionProvider.clear();
session.getTransaction().commit(); session.getTransaction().commit();
session.close(); session.close();
} }
JdbcStatisticsConnectionProvider.PreparedStatementStatistics statementStatistics = PreparedStatement preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name, id) values (?, ?)" );
connectionProvider.getPreparedStatementStatistics().get( verify(preparedStatement, times( 5 )).addBatch();
"insert into Event (name, id) values (?, ?)" ); verify(preparedStatement, times( 3 )).executeBatch();
assertEquals( 5, statementStatistics.getAddBatchCount() );
assertEquals( 3, statementStatistics.getExecuteBatchCount() );
} }
@Test @Test
public void testSessionSettingOverridesSessionFactorySetting() { public void testSessionSettingOverridesSessionFactorySetting()
throws SQLException {
Session session = sessionFactory().openSession(); Session session = sessionFactory().openSession();
session.setJdbcBatchSize( 3 ); session.setJdbcBatchSize( 3 );
session.beginTransaction(); session.beginTransaction();
@ -85,11 +87,10 @@ public class SessionJdbcBatchTest
session.getTransaction().commit(); session.getTransaction().commit();
session.close(); session.close();
} }
JdbcStatisticsConnectionProvider.PreparedStatementStatistics statementStatistics =
connectionProvider.getPreparedStatementStatistics().get( PreparedStatement preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name, id) values (?, ?)" );
"insert into Event (name, id) values (?, ?)" ); verify(preparedStatement, times( 5 )).addBatch();
assertEquals( 5, statementStatistics.getAddBatchCount() ); verify(preparedStatement, times( 2 )).executeBatch();
assertEquals( 2, statementStatistics.getExecuteBatchCount() );
session = sessionFactory().openSession(); session = sessionFactory().openSession();
session.setJdbcBatchSize( null ); session.setJdbcBatchSize( null );
@ -102,11 +103,9 @@ public class SessionJdbcBatchTest
session.getTransaction().commit(); session.getTransaction().commit();
session.close(); session.close();
} }
statementStatistics = preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name, id) values (?, ?)" );
connectionProvider.getPreparedStatementStatistics().get( verify(preparedStatement, times( 5 )).addBatch();
"insert into Event (name, id) values (?, ?)" ); verify(preparedStatement, times( 3 )).executeBatch();
assertEquals( 5, statementStatistics.getAddBatchCount() );
assertEquals( 3, statementStatistics.getExecuteBatchCount() );
} }
private void addEvents(Session session) { private void addEvents(Session session) {
@ -118,20 +117,6 @@ public class SessionJdbcBatchTest
} }
} }
@Test
public void testSessionJdbcBatchOverridesSessionFactorySetting() {
Session session = sessionFactory().openSession();
session.beginTransaction();
try {
}
finally {
session.getTransaction().commit();
session.close();
}
}
@Entity(name = "Event") @Entity(name = "Event")
public static class Event { public static class Event {

View File

@ -10,7 +10,7 @@ import org.hibernate.dialect.SQLServer2005Dialect;
import org.hibernate.testing.RequiresDialect; import org.hibernate.testing.RequiresDialect;
import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase; import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase;
import org.hibernate.test.util.SQLStatementInterceptor; import org.hibernate.test.util.jdbc.SQLStatementInterceptor;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;

View File

@ -1,115 +0,0 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.test.util;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import org.hibernate.engine.jdbc.connections.internal.DriverManagerConnectionProviderImpl;
import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer;
/**
* @author Vlad Mihalcea
*/
public class JdbcStatisticsConnectionProvider extends
DriverManagerConnectionProviderImpl {
public static class PreparedStatementStatistics {
private final String sql;
private int executeUpdateCount;
private int addBatchCount;
private int executeBatchCount;
public PreparedStatementStatistics(String sql) {
this.sql = sql;
}
public int getExecuteUpdateCount() {
return executeUpdateCount;
}
public int getAddBatchCount() {
return addBatchCount;
}
public int getExecuteBatchCount() {
return executeBatchCount;
}
}
private final Map<PreparedStatement, PreparedStatementStatistics> preparedStatementStatisticsMap = new HashMap<>();
@Override
public Connection getConnection() throws SQLException {
return spy( super.getConnection() );
}
private Connection spy(Connection connection) {
if ( new MockUtil().isMock( connection ) ) {
return connection;
}
Connection connectionSpy = Mockito.spy( connection );
try {
doAnswer( invocation -> {
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = Mockito.spy( statement );
preparedStatementStatisticsMap.putIfAbsent( statementSpy,
new PreparedStatementStatistics(
(String) invocation
.getArguments()[0] )
);
doAnswer( _invocation -> {
Object mock = _invocation.getMock();
preparedStatementStatisticsMap.get( mock ).executeUpdateCount++;
return _invocation.callRealMethod();
} ).when( statementSpy ).executeUpdate();
doAnswer( _invocation -> {
Object mock = _invocation.getMock();
preparedStatementStatisticsMap.get( mock ).addBatchCount++;
return _invocation.callRealMethod();
} ).when( statementSpy ).addBatch();
doAnswer( _invocation -> {
Object mock = _invocation.getMock();
preparedStatementStatisticsMap.get( mock ).executeBatchCount++;
return _invocation.callRealMethod();
} ).when( statementSpy ).executeBatch();
return statementSpy;
} ).when( connectionSpy ).prepareStatement( anyString() );
}
catch ( SQLException e ) {
e.printStackTrace();
}
return connectionSpy;
}
public void clear() {
preparedStatementStatisticsMap.clear();
}
public Map<String, PreparedStatementStatistics> getPreparedStatementStatistics() {
Map<String, PreparedStatementStatistics> statisticsMap = new HashMap<>();
preparedStatementStatisticsMap.values()
.stream()
.forEach( stats -> statisticsMap.put( stats.sql, stats ) );
return statisticsMap;
}
}

View File

@ -0,0 +1,64 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.test.util.jdbc;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import org.hibernate.engine.jdbc.connections.internal.DriverManagerConnectionProviderImpl;
import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer;
/**
* @author Vlad Mihalcea
*/
public class PreparedStatementSpyConnectionProvider extends
DriverManagerConnectionProviderImpl {
private final Map<String, PreparedStatement> preparedStatementStatisticsMap = new HashMap<>();
@Override
public Connection getConnection() throws SQLException {
return spy( super.getConnection() );
}
private Connection spy(Connection connection) {
if ( new MockUtil().isMock( connection ) ) {
return connection;
}
Connection connectionSpy = Mockito.spy( connection );
try {
doAnswer( invocation -> {
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = Mockito.spy( statement );
String sql = (String) invocation.getArguments()[0];
preparedStatementStatisticsMap.put( sql, statementSpy );
return statementSpy;
} ).when( connectionSpy ).prepareStatement( anyString() );
}
catch ( SQLException e ) {
throw new IllegalArgumentException( e );
}
return connectionSpy;
}
public void clear() {
preparedStatementStatisticsMap.values().forEach( Mockito::reset );
preparedStatementStatisticsMap.clear();
}
public PreparedStatement getPreparedStatement(String sql) {
return preparedStatementStatisticsMap.get( sql );
}
}

View File

@ -1,4 +1,4 @@
package org.hibernate.test.util; package org.hibernate.test.util.jdbc;
import java.util.LinkedList; import java.util.LinkedList;