From 96d4639e6cd6ecedb61f228b22de2ad7fa324b0e Mon Sep 17 00:00:00 2001 From: Scott Marlow Date: Thu, 17 Dec 2015 10:17:07 -0500 Subject: [PATCH] HHH-10384 Fix thread safety issues in thread local optimiser --- .../PooledLoThreadLocalOptimizer.java | 35 ++++++++----------- .../id/enhanced/OptimizerUnitTest.java | 10 +++--- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/hibernate-core/src/main/java/org/hibernate/id/enhanced/PooledLoThreadLocalOptimizer.java b/hibernate-core/src/main/java/org/hibernate/id/enhanced/PooledLoThreadLocalOptimizer.java index 234b588036..03525773b4 100644 --- a/hibernate-core/src/main/java/org/hibernate/id/enhanced/PooledLoThreadLocalOptimizer.java +++ b/hibernate-core/src/main/java/org/hibernate/id/enhanced/PooledLoThreadLocalOptimizer.java @@ -7,8 +7,8 @@ package org.hibernate.id.enhanced; import java.io.Serializable; +import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import org.hibernate.HibernateException; import org.hibernate.id.IntegralDataTypeHolder; @@ -65,49 +65,38 @@ public class PooledLoThreadLocalOptimizer extends AbstractOptimizer { } synchronized (this) { - final GenerationState generationState = locateGenerationState(callback.getTenantIdentifier()); + final GenerationState generationState = locateGenerationState( callback.getTenantIdentifier() ); - if (generationState.lastSourceValue == null - || !generationState.value.lt(generationState.upperLimitValue)) { + if ( generationState.lastSourceValue == null + || !generationState.value.lt( generationState.upperLimitValue )) { generationState.lastSourceValue = callback.getNextValue(); - generationState.upperLimitValue = generationState.lastSourceValue.copy().add(incrementSize); + generationState.upperLimitValue = generationState.lastSourceValue.copy().add( incrementSize ); generationState.value = generationState.lastSourceValue.copy(); // handle cases where initial-value is less that one (hsqldb for instance). - while (generationState.value.lt(1)) { + while (generationState.value.lt( 1 )) { generationState.value.increment(); } } - if(callback.getTenantIdentifier() != null) { - return generationState.value.makeValueThenIncrement(); - } else { - if ( local == null ) { - local = new GenerationState(); - localAssignedIds.set( local ); - } - local.upperLimitValue = generationState.upperLimitValue.copy(); - local.value = generationState.value.copy(); - local.lastSourceValue = generationState.lastSourceValue.copy(); - generationState.value = generationState.upperLimitValue.copy(); - return local.value.makeValueThenIncrement(); - } + return generationState.value.makeValueThenIncrement(); } } - private GenerationState noTenantState; private Map tenantSpecificState; private final ThreadLocal localAssignedIds = new ThreadLocal(); private GenerationState locateGenerationState(String tenantIdentifier) { if ( tenantIdentifier == null ) { + GenerationState noTenantState = localAssignedIds.get(); if ( noTenantState == null ) { noTenantState = new GenerationState(); + localAssignedIds.set(noTenantState); } return noTenantState; } else { GenerationState state; if ( tenantSpecificState == null ) { - tenantSpecificState = new ConcurrentHashMap(); + tenantSpecificState = new HashMap(); state = new GenerationState(); tenantSpecificState.put( tenantIdentifier, state ); } @@ -122,13 +111,17 @@ public class PooledLoThreadLocalOptimizer extends AbstractOptimizer { } } + // for Hibernate testsuite use only private GenerationState noTenantGenerationState() { + GenerationState noTenantState = locateGenerationState( null ); + if ( noTenantState == null ) { throw new IllegalStateException( "Could not locate previous generation state for no-tenant" ); } return noTenantState; } + // for Hibernate testsuite use only @Override public IntegralDataTypeHolder getLastSourceValue() { return noTenantGenerationState().lastSourceValue; diff --git a/hibernate-core/src/test/java/org/hibernate/id/enhanced/OptimizerUnitTest.java b/hibernate-core/src/test/java/org/hibernate/id/enhanced/OptimizerUnitTest.java index f4593dfc1a..c9b0ef2239 100644 --- a/hibernate-core/src/test/java/org/hibernate/id/enhanced/OptimizerUnitTest.java +++ b/hibernate-core/src/test/java/org/hibernate/id/enhanced/OptimizerUnitTest.java @@ -261,8 +261,8 @@ public class OptimizerUnitTest extends BaseUnitTestCase { @Test public void testBasicPooledThreadLocalLoOptimizerUsage() { - final SourceMock sequence = new SourceMock( 1, 5000 ); // pass 5000 to match default for PooledThreadLocalLoOptimizer.THREAD_LOCAL_BLOCK_SIZE - final Optimizer optimizer = buildPooledThreadLocalLoOptimizer( 1, 5000 ); + final SourceMock sequence = new SourceMock( 1, 50 ); // pass 5000 to match default for PooledThreadLocalLoOptimizer.THREAD_LOCAL_BLOCK_SIZE + final Optimizer optimizer = buildPooledThreadLocalLoOptimizer( 1, 50 ); assertEquals( 0, sequence.getTimesCalled() ); assertEquals( -1, sequence.getCurrentValue() ); @@ -282,13 +282,13 @@ public class OptimizerUnitTest extends BaseUnitTestCase { assertEquals( 1, sequence.getTimesCalled() ); assertEquals( 1, sequence.getCurrentValue() ); - for( int looper = 0; looper < 5001; looper++) { + for( int looper = 0; looper < 51; looper++) { next = ( Long ) optimizer.generate( sequence ); } - assertEquals( 3 + 5001, next.intValue() ); + assertEquals( 3 + 51, next.intValue() ); assertEquals( 2, sequence.getTimesCalled() ); - assertEquals( 5001, sequence.getCurrentValue() ); + assertEquals( 51, sequence.getCurrentValue() ); }