diff --git a/hibernate-core/src/main/java/org/hibernate/engine/jdbc/connections/internal/ConnectionProviderInitiator.java b/hibernate-core/src/main/java/org/hibernate/engine/jdbc/connections/internal/ConnectionProviderInitiator.java index 055a91461c..c2599a5b89 100644 --- a/hibernate-core/src/main/java/org/hibernate/engine/jdbc/connections/internal/ConnectionProviderInitiator.java +++ b/hibernate-core/src/main/java/org/hibernate/engine/jdbc/connections/internal/ConnectionProviderInitiator.java @@ -4,6 +4,7 @@ */ package org.hibernate.engine.jdbc.connections.internal; +import java.lang.reflect.InvocationTargetException; import java.sql.Connection; import java.util.Collection; import java.util.HashSet; @@ -19,6 +20,11 @@ import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider; import org.hibernate.internal.CoreLogging; import org.hibernate.internal.CoreMessageLogger; import org.hibernate.internal.util.StringHelper; +import org.hibernate.resource.beans.container.spi.BeanContainer; +import org.hibernate.resource.beans.internal.FallbackBeanInstanceProducer; +import org.hibernate.resource.beans.internal.Helper; +import org.hibernate.resource.beans.spi.BeanInstanceProducer; +import org.hibernate.resource.beans.spi.ManagedBeanRegistry; import org.hibernate.service.spi.ServiceRegistryImplementor; import static java.sql.Connection.TRANSACTION_NONE; @@ -102,6 +108,7 @@ public class ConnectionProviderInitiator implements StandardServiceInitiator providerClass ) { LOG.instantiatingExplicitConnectionProvider( providerClass.getName() ); - return instantiateExplicitConnectionProvider( providerClass ); + return instantiateExplicitConnectionProvider( providerClass, beanContainer ); } else { final String providerName = nullIfEmpty( explicitSetting.toString() ); if ( providerName != null ) { - return instantiateNamedConnectionProvider(providerName, strategySelector); + return instantiateNamedConnectionProvider(providerName, strategySelector, beanContainer); } } } - return instantiateConnectionProvider( configurationValues, strategySelector ); + return instantiateConnectionProvider( configurationValues, strategySelector, beanContainer ); } - private ConnectionProvider instantiateNamedConnectionProvider(String providerName, StrategySelector strategySelector) { + private ConnectionProvider instantiateNamedConnectionProvider(String providerName, StrategySelector strategySelector, BeanContainer beanContainer) { LOG.instantiatingExplicitConnectionProvider( providerName ); final Class providerClass = strategySelector.selectStrategyImplementor( ConnectionProvider.class, providerName ); try { - return instantiateExplicitConnectionProvider( providerClass ); + return instantiateExplicitConnectionProvider( providerClass, beanContainer ); } catch (Exception e) { throw new HibernateException( @@ -140,7 +147,7 @@ public class ConnectionProviderInitiator implements StandardServiceInitiator configurationValues, StrategySelector strategySelector) { + Map configurationValues, StrategySelector strategySelector, BeanContainer beanContainer) { if ( configurationValues.containsKey( DATASOURCE ) ) { return new DatasourceConnectionProviderImpl(); } @@ -149,9 +156,9 @@ public class ConnectionProviderInitiator implements StandardServiceInitiator B produceBeanInstance(Class beanType) { + return (B) noAppropriateConnectionProvider(); + } + + @Override + public B produceBeanInstance(String name, Class beanType) { + return (B) noAppropriateConnectionProvider(); + } + + } + ).getBeanInstance(); + } + else { + return noAppropriateConnectionProvider(); + } + } } + private ConnectionProvider noAppropriateConnectionProvider() { + LOG.noAppropriateConnectionProvider(); + return new UserSuppliedConnectionProviderImpl(); + } + private Class getSingleRegisteredProvider(StrategySelector strategySelector) { final Collection> implementors = strategySelector.getRegisteredStrategyImplementors( ConnectionProvider.class ); @@ -190,9 +233,28 @@ public class ConnectionProviderInitiator implements StandardServiceInitiator providerClass) { + private ConnectionProvider instantiateExplicitConnectionProvider(Class providerClass, BeanContainer beanContainer) { try { - return (ConnectionProvider) providerClass.newInstance(); + if ( beanContainer != null ) { + return (ConnectionProvider) beanContainer.getBean( + providerClass, + new BeanContainer.LifecycleOptions() { + @Override + public boolean canUseCachedReferences() { + return true; + } + + @Override + public boolean useJpaCompliantCreation() { + return true; + } + }, + FallbackBeanInstanceProducer.INSTANCE + ).getBeanInstance(); + } + else { + return (ConnectionProvider) providerClass.getConstructor().newInstance(); + } } catch (Exception e) { throw new HibernateException( "Could not instantiate connection provider [" + providerClass.getName() + "]", e ); @@ -201,7 +263,7 @@ public class ConnectionProviderInitiator implements StandardServiceInitiator createSettings() { + Map settings = new HashMap<>(); + settings.put( AvailableSettings.ALLOW_EXTENSIONS_IN_CDI, "true" ); + settings.put( AvailableSettings.BEAN_CONTAINER, new BeanContainer() { + @Override + @SuppressWarnings("unchecked") + public ContainedBean getBean( + Class beanType, + LifecycleOptions lifecycleOptions, + BeanInstanceProducer fallbackProducer) { + return () -> (B) ( beanType == DummyConnectionProvider.class ? + dummyConnectionProvider : fallbackProducer.produceBeanInstance( beanType ) ); + } + + @Override + public ContainedBean getBean( + String name, + Class beanType, + LifecycleOptions lifecycleOptions, + BeanInstanceProducer fallbackProducer) { + return () -> (B) fallbackProducer.produceBeanInstance( beanType ); + } + + @Override + public void stop() { + + } + } ); + return settings; + } + + @Test + public void testProviderFromBeanContainerInUse() { + Map settings = createSettings(); + settings.putIfAbsent( CONNECTION_PROVIDER, DummyConnectionProvider.class.getName() ); + try ( ServiceRegistry serviceRegistry = ServiceRegistryUtil.serviceRegistryBuilder() + .applySettings( settings ).build() ) { + ConnectionProvider providerInUse = serviceRegistry.getService( ConnectionProvider.class ); + assertSame( dummyConnectionProvider, providerInUse ); + } + } + + public static class DummyConnectionProvider implements ConnectionProvider { + + @Override + public boolean isUnwrappableAs(Class unwrapType) { + return false; + } + + @Override + public T unwrap(Class unwrapType) { + return null; + } + + @Override + public Connection getConnection() throws SQLException { + return null; + } + + @Override + public void closeConnection(Connection connection) throws SQLException { + + } + + @Override + public boolean supportsAggressiveRelease() { + return false; + } + }; +}