Change PreparedStatementSpyConnectionProvider so that it works with any given ConnectionProvider

This commit is contained in:
Vlad Mihalcea 2016-06-14 13:52:35 +03:00
parent 6142f92d2f
commit e9b48a881d
2 changed files with 121 additions and 13 deletions

View File

@ -8,6 +8,7 @@ package org.hibernate.test.jdbc.internal;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import javax.persistence.Entity;
import javax.persistence.Id;
@ -19,6 +20,7 @@ import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase;
import org.hibernate.test.util.jdbc.PreparedStatementSpyConnectionProvider;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -68,7 +70,8 @@ public class SessionJdbcBatchTest
session.getTransaction().commit();
session.close();
}
PreparedStatement preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name, id) values (?, ?)" );
PreparedStatement preparedStatement = connectionProvider.getPreparedStatement(
"insert into Event (name, id) values (?, ?)" );
verify( preparedStatement, times( 5 ) ).addBatch();
verify( preparedStatement, times( 3 ) ).executeBatch();
}
@ -103,7 +106,9 @@ public class SessionJdbcBatchTest
session.getTransaction().commit();
session.close();
}
preparedStatement = connectionProvider.getPreparedStatement( "insert into Event (name, id) values (?, ?)" );
List<PreparedStatement> preparedStatements = connectionProvider.getPreparedStatements();
assertEquals(1, preparedStatements.size());
preparedStatement = preparedStatements.get( 0 );
verify(preparedStatement, times( 5 )).addBatch();
verify(preparedStatement, times( 3 )).executeBatch();
}

View File

@ -9,10 +9,19 @@ package org.hibernate.test.util.jdbc;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.hibernate.engine.jdbc.connections.internal.DriverManagerConnectionProviderImpl;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.engine.jdbc.connections.internal.ConnectionProviderInitiator;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.service.spi.Configurable;
import org.hibernate.service.spi.ServiceRegistryAwareService;
import org.hibernate.service.spi.ServiceRegistryImplementor;
import org.mockito.Mockito;
import org.mockito.internal.util.MockUtil;
@ -21,16 +30,65 @@ import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer;
/**
* This {@link ConnectionProvider} extends any other ConnectionProvider that would be used by default taken the current configuration properties, and it
* intercept the underlying {@link PreparedStatement} method calls.
*
* @author Vlad Mihalcea
*/
public class PreparedStatementSpyConnectionProvider extends
DriverManagerConnectionProviderImpl {
public class PreparedStatementSpyConnectionProvider implements
ConnectionProvider,
Configurable,
ServiceRegistryAwareService {
private final Map<String, PreparedStatement> preparedStatementStatisticsMap = new HashMap<>();
private ServiceRegistryImplementor serviceRegistry;
private ConnectionProvider connectionProvider;
private final Map<PreparedStatement, String> preparedStatementMap = new LinkedHashMap<>();
@Override
public void configure(Map configurationValues) {
@SuppressWarnings("unchecked")
Map<String, Object> settings = new HashMap<>( configurationValues );
settings.remove( AvailableSettings.CONNECTION_PROVIDER );
connectionProvider = ConnectionProviderInitiator.INSTANCE.initiateService(
settings,
serviceRegistry
);
if ( connectionProvider instanceof Configurable ) {
Configurable configurableConnectionProvider = (Configurable) connectionProvider;
configurableConnectionProvider.configure( settings );
}
}
@Override
public void injectServices(ServiceRegistryImplementor serviceRegistry) {
this.serviceRegistry = serviceRegistry;
}
@Override
public Connection getConnection() throws SQLException {
return spy( super.getConnection() );
return spy( connectionProvider.getConnection() );
}
@Override
public void closeConnection(Connection conn) throws SQLException {
connectionProvider.closeConnection( conn );
}
@Override
public boolean supportsAggressiveRelease() {
return connectionProvider.supportsAggressiveRelease();
}
@Override
public boolean isUnwrappableAs(Class unwrapType) {
return connectionProvider.isUnwrappableAs( unwrapType );
}
@Override
public <T> T unwrap(Class<T> unwrapType) {
return connectionProvider.unwrap( unwrapType );
}
private Connection spy(Connection connection) {
@ -43,7 +101,7 @@ public class PreparedStatementSpyConnectionProvider extends
PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
PreparedStatement statementSpy = Mockito.spy( statement );
String sql = (String) invocation.getArguments()[0];
preparedStatementStatisticsMap.put( sql, statementSpy );
preparedStatementMap.put( statementSpy, sql );
return statementSpy;
} ).when( connectionSpy ).prepareStatement( anyString() );
}
@ -53,12 +111,57 @@ public class PreparedStatementSpyConnectionProvider extends
return connectionSpy;
}
/**
* Clears the recorded PreparedStatements and reset the associated Mocks.
*/
public void clear() {
preparedStatementStatisticsMap.values().forEach( Mockito::reset );
preparedStatementStatisticsMap.clear();
preparedStatementMap.keySet().forEach( Mockito::reset );
preparedStatementMap.clear();
}
/**
* Get one and only one PreparedStatement associated to the given SQL statement.
*
* @param sql SQL statement.
*
* @return matching PreparedStatement.
*
* @throws IllegalArgumentException If there is no matching PreparedStatement or multiple instances, an exception is being thrown.
*/
public PreparedStatement getPreparedStatement(String sql) {
return preparedStatementStatisticsMap.get( sql );
List<PreparedStatement> preparedStatements = getPreparedStatements( sql );
if ( preparedStatements.isEmpty() ) {
throw new IllegalArgumentException(
"There is no PreparedStatement for this SQL statement " + sql );
}
else if ( preparedStatements.size() > 1 ) {
throw new IllegalArgumentException( "There are " + preparedStatements
.size() + " PreparedStatements for this SQL statement " + sql );
}
return preparedStatements.get( 0 );
}
/**
* Get the PreparedStatements that are associated to the following SQL statement.
*
* @param sql SQL statement.
*
* @return list of recorded PreparedStatements matching the SQL statement.
*/
public List<PreparedStatement> getPreparedStatements(String sql) {
return preparedStatementMap.entrySet()
.stream()
.filter( entry -> entry.getValue().equals( sql ) )
.map( Map.Entry::getKey )
.collect( Collectors.toList() );
}
/**
* Get the PreparedStatements that were executed since the last clear operation.
*
* @return list of recorded PreparedStatements.
*/
public List<PreparedStatement> getPreparedStatements() {
return new ArrayList<>( preparedStatementMap.keySet() );
}
}