Change PreparedStatementSpyConnectionProvider so that it works with any given ConnectionProvider

This commit is contained in:
Vlad Mihalcea 2016-06-14 13:52:35 +03:00
parent 6142f92d2f
commit e9b48a881d
2 changed files with 121 additions and 13 deletions

View File

@ -8,6 +8,7 @@ package org.hibernate.test.jdbc.internal;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.List;
import java.util.Map; import java.util.Map;
import javax.persistence.Entity; import javax.persistence.Entity;
import javax.persistence.Id; import javax.persistence.Id;
@ -19,6 +20,7 @@ import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase;
import org.hibernate.test.util.jdbc.PreparedStatementSpyConnectionProvider; import org.hibernate.test.util.jdbc.PreparedStatementSpyConnectionProvider;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -68,9 +70,10 @@ public class SessionJdbcBatchTest
session.getTransaction().commit(); session.getTransaction().commit();
session.close(); session.close();
} }
PreparedStatement preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name, id) values (?, ?)" ); PreparedStatement preparedStatement = connectionProvider.getPreparedStatement(
verify(preparedStatement, times( 5 )).addBatch(); "insert into Event (name, id) values (?, ?)" );
verify(preparedStatement, times( 3 )).executeBatch(); verify( preparedStatement, times( 5 ) ).addBatch();
verify( preparedStatement, times( 3 ) ).executeBatch();
} }
@Test @Test
@ -103,7 +106,9 @@ public class SessionJdbcBatchTest
session.getTransaction().commit(); session.getTransaction().commit();
session.close(); session.close();
} }
preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name, id) values (?, ?)" ); List<PreparedStatement> preparedStatements = connectionProvider.getPreparedStatements();
assertEquals(1, preparedStatements.size());
preparedStatement = preparedStatements.get( 0 );
verify(preparedStatement, times( 5 )).addBatch(); verify(preparedStatement, times( 5 )).addBatch();
verify(preparedStatement, times( 3 )).executeBatch(); verify(preparedStatement, times( 3 )).executeBatch();
} }

View File

@ -9,10 +9,19 @@ package org.hibernate.test.util.jdbc;
import java.sql.Connection; import java.sql.Connection;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
import org.hibernate.engine.jdbc.connections.internal.DriverManagerConnectionProviderImpl; import org.hibernate.cfg.AvailableSettings;
import org.hibernate.engine.jdbc.connections.internal.ConnectionProviderInitiator;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.service.spi.Configurable;
import org.hibernate.service.spi.ServiceRegistryAwareService;
import org.hibernate.service.spi.ServiceRegistryImplementor;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil; import org.mockito.internal.util.MockUtil;
@ -21,16 +30,65 @@ import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
/** /**
* This {@link ConnectionProvider} extends any other ConnectionProvider that would be used by default taken the current configuration properties, and it
* intercept the underlying {@link PreparedStatement} method calls.
*
* @author Vlad Mihalcea * @author Vlad Mihalcea
*/ */
public class PreparedStatementSpyConnectionProvider extends public class PreparedStatementSpyConnectionProvider implements
DriverManagerConnectionProviderImpl { ConnectionProvider,
Configurable,
ServiceRegistryAwareService {
private final Map<String, PreparedStatement> preparedStatementStatisticsMap = new HashMap<>(); private ServiceRegistryImplementor serviceRegistry;
private ConnectionProvider connectionProvider;
private final Map<PreparedStatement, String> preparedStatementMap = new LinkedHashMap<>();
@Override
public void configure(Map configurationValues) {
@SuppressWarnings("unchecked")
Map<String, Object> settings = new HashMap<>( configurationValues );
settings.remove( AvailableSettings.CONNECTION_PROVIDER );
connectionProvider = ConnectionProviderInitiator.INSTANCE.initiateService(
settings,
serviceRegistry
);
if ( connectionProvider instanceof Configurable ) {
Configurable configurableConnectionProvider = (Configurable) connectionProvider;
configurableConnectionProvider.configure( settings );
}
}
@Override
public void injectServices(ServiceRegistryImplementor serviceRegistry) {
this.serviceRegistry = serviceRegistry;
}
@Override @Override
public Connection getConnection() throws SQLException { public Connection getConnection() throws SQLException {
return spy( super.getConnection() ); return spy( connectionProvider.getConnection() );
}
@Override
public void closeConnection(Connection conn) throws SQLException {
connectionProvider.closeConnection( conn );
}
@Override
public boolean supportsAggressiveRelease() {
return connectionProvider.supportsAggressiveRelease();
}
@Override
public boolean isUnwrappableAs(Class unwrapType) {
return connectionProvider.isUnwrappableAs( unwrapType );
}
@Override
public <T> T unwrap(Class<T> unwrapType) {
return connectionProvider.unwrap( unwrapType );
} }
private Connection spy(Connection connection) { private Connection spy(Connection connection) {
@ -43,7 +101,7 @@ public class PreparedStatementSpyConnectionProvider extends
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod(); PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = Mockito.spy( statement ); PreparedStatement statementSpy = Mockito.spy( statement );
String sql = (String) invocation.getArguments()[0]; String sql = (String) invocation.getArguments()[0];
preparedStatementStatisticsMap.put( sql, statementSpy ); preparedStatementMap.put( statementSpy, sql );
return statementSpy; return statementSpy;
} ).when( connectionSpy ).prepareStatement( anyString() ); } ).when( connectionSpy ).prepareStatement( anyString() );
} }
@ -53,12 +111,57 @@ public class PreparedStatementSpyConnectionProvider extends
return connectionSpy; return connectionSpy;
} }
/**
* Clears the recorded PreparedStatements and reset the associated Mocks.
*/
public void clear() { public void clear() {
preparedStatementStatisticsMap.values().forEach( Mockito::reset ); preparedStatementMap.keySet().forEach( Mockito::reset );
preparedStatementStatisticsMap.clear(); preparedStatementMap.clear();
} }
/**
* Get one and only one PreparedStatement associated to the given SQL statement.
*
* @param sql SQL statement.
*
* @return matching PreparedStatement.
*
* @throws IllegalArgumentException If there is no matching PreparedStatement or multiple instances, an exception is being thrown.
*/
public PreparedStatement getPreparedStatement(String sql) { public PreparedStatement getPreparedStatement(String sql) {
return preparedStatementStatisticsMap.get( sql ); List<PreparedStatement> preparedStatements = getPreparedStatements( sql );
if ( preparedStatements.isEmpty() ) {
throw new IllegalArgumentException(
"There is no PreparedStatement for this SQL statement " + sql );
}
else if ( preparedStatements.size() > 1 ) {
throw new IllegalArgumentException( "There are " + preparedStatements
.size() + " PreparedStatements for this SQL statement " + sql );
}
return preparedStatements.get( 0 );
}
/**
* Get the PreparedStatements that are associated to the following SQL statement.
*
* @param sql SQL statement.
*
* @return list of recorded PreparedStatements matching the SQL statement.
*/
public List<PreparedStatement> getPreparedStatements(String sql) {
return preparedStatementMap.entrySet()
.stream()
.filter( entry -> entry.getValue().equals( sql ) )
.map( Map.Entry::getKey )
.collect( Collectors.toList() );
}
/**
* Get the PreparedStatements that were executed since the last clear operation.
*
* @return list of recorded PreparedStatements.
*/
public List<PreparedStatement> getPreparedStatements() {
return new ArrayList<>( preparedStatementMap.keySet() );
} }
} }