HHH-10384 Fix thread safety issues in thread local optimiser

This commit is contained in:
Scott Marlow 2015-12-17 10:17:07 -05:00
parent bef14a5890
commit 96d4639e6c
2 changed files with 19 additions and 26 deletions

View File

@ -7,8 +7,8 @@
package org.hibernate.id.enhanced;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.hibernate.HibernateException;
import org.hibernate.id.IntegralDataTypeHolder;
@ -65,49 +65,38 @@ public class PooledLoThreadLocalOptimizer extends AbstractOptimizer {
}
synchronized (this) {
final GenerationState generationState = locateGenerationState(callback.getTenantIdentifier());
final GenerationState generationState = locateGenerationState( callback.getTenantIdentifier() );
if (generationState.lastSourceValue == null
|| !generationState.value.lt(generationState.upperLimitValue)) {
if ( generationState.lastSourceValue == null
|| !generationState.value.lt( generationState.upperLimitValue )) {
generationState.lastSourceValue = callback.getNextValue();
generationState.upperLimitValue = generationState.lastSourceValue.copy().add(incrementSize);
generationState.upperLimitValue = generationState.lastSourceValue.copy().add( incrementSize );
generationState.value = generationState.lastSourceValue.copy();
// handle cases where initial-value is less that one (hsqldb for instance).
while (generationState.value.lt(1)) {
while (generationState.value.lt( 1 )) {
generationState.value.increment();
}
}
if(callback.getTenantIdentifier() != null) {
return generationState.value.makeValueThenIncrement();
} else {
if ( local == null ) {
local = new GenerationState();
localAssignedIds.set( local );
}
local.upperLimitValue = generationState.upperLimitValue.copy();
local.value = generationState.value.copy();
local.lastSourceValue = generationState.lastSourceValue.copy();
generationState.value = generationState.upperLimitValue.copy();
return local.value.makeValueThenIncrement();
}
return generationState.value.makeValueThenIncrement();
}
}
private GenerationState noTenantState;
private Map<String, GenerationState> tenantSpecificState;
private final ThreadLocal<GenerationState> localAssignedIds = new ThreadLocal<GenerationState>();
private GenerationState locateGenerationState(String tenantIdentifier) {
if ( tenantIdentifier == null ) {
GenerationState noTenantState = localAssignedIds.get();
if ( noTenantState == null ) {
noTenantState = new GenerationState();
localAssignedIds.set(noTenantState);
}
return noTenantState;
}
else {
GenerationState state;
if ( tenantSpecificState == null ) {
tenantSpecificState = new ConcurrentHashMap<String, GenerationState>();
tenantSpecificState = new HashMap<String, GenerationState>();
state = new GenerationState();
tenantSpecificState.put( tenantIdentifier, state );
}
@ -122,13 +111,17 @@ public class PooledLoThreadLocalOptimizer extends AbstractOptimizer {
}
}
// for Hibernate testsuite use only
private GenerationState noTenantGenerationState() {
GenerationState noTenantState = locateGenerationState( null );
if ( noTenantState == null ) {
throw new IllegalStateException( "Could not locate previous generation state for no-tenant" );
}
return noTenantState;
}
// for Hibernate testsuite use only
@Override
public IntegralDataTypeHolder getLastSourceValue() {
return noTenantGenerationState().lastSourceValue;

View File

@ -261,8 +261,8 @@ public class OptimizerUnitTest extends BaseUnitTestCase {
@Test
public void testBasicPooledThreadLocalLoOptimizerUsage() {
final SourceMock sequence = new SourceMock( 1, 5000 ); // pass 5000 to match default for PooledThreadLocalLoOptimizer.THREAD_LOCAL_BLOCK_SIZE
final Optimizer optimizer = buildPooledThreadLocalLoOptimizer( 1, 5000 );
final SourceMock sequence = new SourceMock( 1, 50 ); // pass 5000 to match default for PooledThreadLocalLoOptimizer.THREAD_LOCAL_BLOCK_SIZE
final Optimizer optimizer = buildPooledThreadLocalLoOptimizer( 1, 50 );
assertEquals( 0, sequence.getTimesCalled() );
assertEquals( -1, sequence.getCurrentValue() );
@ -282,13 +282,13 @@ public class OptimizerUnitTest extends BaseUnitTestCase {
assertEquals( 1, sequence.getTimesCalled() );
assertEquals( 1, sequence.getCurrentValue() );
for( int looper = 0; looper < 5001; looper++) {
for( int looper = 0; looper < 51; looper++) {
next = ( Long ) optimizer.generate( sequence );
}
assertEquals( 3 + 5001, next.intValue() );
assertEquals( 3 + 51, next.intValue() );
assertEquals( 2, sequence.getTimesCalled() );
assertEquals( 5001, sequence.getCurrentValue() );
assertEquals( 51, sequence.getCurrentValue() );
}