Refactor getOrCreateTag method to prevent transaction suspension for XA transaction compatibility

The getOrCreateTag method previously used a propagation behavior that caused issues with
    XA transactions when using the PostgreSQL JDBC driver. The PGXAConnection does not support
    transaction suspend/resume, which made it incompatible with the existing propagation strategy
    'PROPAGATION_REQUIRES_NEW'.

    This refactor changes the getOrCreateTag logic to perform a lookup/write in a new transaction
    as before, but running in a separate thread, such that the main transaction is not suspended.
    The result is retrieved through a future.

    This change aims to improve compatibility and prevent transaction-related issues when using HAPI-FHIR with
    XA transactions and PostgreSQL.

    Closes #3412
This commit is contained in:
Ibrahim Tallouzi 2024-08-21 16:06:58 +02:00
parent 7ded985577
commit 06c8fd0e31
No known key found for this signature in database
GPG Key ID: 8C7D50B9D591234B
2 changed files with 259 additions and 136 deletions

View File

@ -89,7 +89,6 @@ import ca.uhn.fhir.rest.api.InterceptorInvocationTimingEnum;
import ca.uhn.fhir.rest.api.RestOperationTypeEnum; import ca.uhn.fhir.rest.api.RestOperationTypeEnum;
import ca.uhn.fhir.rest.api.server.RequestDetails; import ca.uhn.fhir.rest.api.server.RequestDetails;
import ca.uhn.fhir.rest.api.server.storage.TransactionDetails; import ca.uhn.fhir.rest.api.server.storage.TransactionDetails;
import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
import ca.uhn.fhir.rest.server.exceptions.InvalidRequestException; import ca.uhn.fhir.rest.server.exceptions.InvalidRequestException;
import ca.uhn.fhir.rest.server.exceptions.UnprocessableEntityException; import ca.uhn.fhir.rest.server.exceptions.UnprocessableEntityException;
import ca.uhn.fhir.rest.server.servlet.ServletRequestDetails; import ca.uhn.fhir.rest.server.servlet.ServletRequestDetails;
@ -107,14 +106,9 @@ import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable; import jakarta.annotation.Nullable;
import jakarta.annotation.PostConstruct; import jakarta.annotation.PostConstruct;
import jakarta.persistence.EntityManager; import jakarta.persistence.EntityManager;
import jakarta.persistence.NoResultException; import jakarta.persistence.EntityManagerFactory;
import jakarta.persistence.PersistenceContext; import jakarta.persistence.PersistenceContext;
import jakarta.persistence.PersistenceContextType; import jakarta.persistence.PersistenceContextType;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.Predicate;
import jakarta.persistence.criteria.Root;
import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.Validate;
@ -137,12 +131,8 @@ import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware; import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionSynchronization; import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager; import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.springframework.transaction.support.TransactionTemplate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
@ -158,7 +148,6 @@ import java.util.stream.Collectors;
import javax.xml.stream.events.Characters; import javax.xml.stream.events.Characters;
import javax.xml.stream.events.XMLEvent; import javax.xml.stream.events.XMLEvent;
import static java.util.Objects.isNull;
import static java.util.Objects.nonNull; import static java.util.Objects.nonNull;
import static org.apache.commons.collections4.CollectionUtils.isEqualCollection; import static org.apache.commons.collections4.CollectionUtils.isEqualCollection;
import static org.apache.commons.lang3.StringUtils.isBlank; import static org.apache.commons.lang3.StringUtils.isBlank;
@ -182,8 +171,6 @@ public abstract class BaseHapiFhirDao<T extends IBaseResource> extends BaseStora
public static final long INDEX_STATUS_INDEXED = 1L; public static final long INDEX_STATUS_INDEXED = 1L;
public static final long INDEX_STATUS_INDEXING_FAILED = 2L; public static final long INDEX_STATUS_INDEXING_FAILED = 2L;
public static final String NS_JPA_PROFILE = "https://github.com/hapifhir/hapi-fhir/ns/jpa/profile"; public static final String NS_JPA_PROFILE = "https://github.com/hapifhir/hapi-fhir/ns/jpa/profile";
// total attempts to do a tag transaction
private static final int TOTAL_TAG_READ_ATTEMPTS = 10;
private static final Logger ourLog = LoggerFactory.getLogger(BaseHapiFhirDao.class); private static final Logger ourLog = LoggerFactory.getLogger(BaseHapiFhirDao.class);
private static boolean ourValidationDisabledForUnitTest; private static boolean ourValidationDisabledForUnitTest;
private static boolean ourDisableIncrementOnUpdateForUnitTest = false; private static boolean ourDisableIncrementOnUpdateForUnitTest = false;
@ -254,12 +241,16 @@ public abstract class BaseHapiFhirDao<T extends IBaseResource> extends BaseStora
@Autowired(required = false) @Autowired(required = false)
private IFulltextSearchSvc myFulltextSearchSvc; private IFulltextSearchSvc myFulltextSearchSvc;
@Autowired
protected ResourceHistoryCalculator myResourceHistoryCalculator;
@Autowired @Autowired
private PlatformTransactionManager myTransactionManager; private PlatformTransactionManager myTransactionManager;
@Autowired @Autowired
protected ResourceHistoryCalculator myResourceHistoryCalculator; private EntityManagerFactory myEntityManagerFactory;
protected TagDefinitionDao tagDefinitionDao;
protected final CodingSpy myCodingSpy = new CodingSpy(); protected final CodingSpy myCodingSpy = new CodingSpy();
@VisibleForTesting @VisibleForTesting
@ -484,7 +475,8 @@ public abstract class BaseHapiFhirDao<T extends IBaseResource> extends BaseStora
if (retVal == null) { if (retVal == null) {
// actual DB hit(s) happen here // actual DB hit(s) happen here
retVal = getOrCreateTag(theTagType, theScheme, theTerm, theLabel, theVersion, theUserSelected); retVal = tagDefinitionDao.getOrCreateTag(
theTagType, theScheme, theTerm, theLabel, theVersion, theUserSelected);
TransactionSynchronization sync = new AddTagDefinitionToCacheAfterCommitSynchronization(key, retVal); TransactionSynchronization sync = new AddTagDefinitionToCacheAfterCommitSynchronization(key, retVal);
TransactionSynchronizationManager.registerSynchronization(sync); TransactionSynchronizationManager.registerSynchronization(sync);
@ -496,124 +488,6 @@ public abstract class BaseHapiFhirDao<T extends IBaseResource> extends BaseStora
return retVal; return retVal;
} }
/**
* Gets the tag defined by the fed in values, or saves it if it does not
* exist.
* <p>
* Can also throw an InternalErrorException if something bad happens.
*/
private TagDefinition getOrCreateTag(
TagTypeEnum theTagType,
String theScheme,
String theTerm,
String theLabel,
String theVersion,
Boolean theUserSelected) {
TypedQuery<TagDefinition> q = buildTagQuery(theTagType, theScheme, theTerm, theVersion, theUserSelected);
q.setMaxResults(1);
TransactionTemplate template = new TransactionTemplate(myTransactionManager);
template.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
// this transaction will attempt to get or create the tag,
// repeating (on any failure) 10 times.
// if it fails more than this, we will throw exceptions
TagDefinition retVal;
int count = 0;
HashSet<Throwable> throwables = new HashSet<>();
do {
try {
retVal = template.execute(new TransactionCallback<TagDefinition>() {
// do the actual DB call(s) to read and/or write the values
private TagDefinition readOrCreate() {
TagDefinition val;
try {
val = q.getSingleResult();
} catch (NoResultException e) {
val = new TagDefinition(theTagType, theScheme, theTerm, theLabel);
val.setVersion(theVersion);
val.setUserSelected(theUserSelected);
myEntityManager.persist(val);
}
return val;
}
@Override
public TagDefinition doInTransaction(TransactionStatus status) {
TagDefinition tag = null;
try {
tag = readOrCreate();
} catch (Exception ex) {
// log any exceptions - just in case
// they may be signs of things to come...
ourLog.warn(
"Tag read/write failed: "
+ ex.getMessage() + ". "
+ "This is not a failure on its own, "
+ "but could be useful information in the result of an actual failure.",
ex);
throwables.add(ex);
}
return tag;
}
});
} catch (Exception ex) {
// transaction template can fail if connections to db are exhausted and/or timeout
ourLog.warn(
"Transaction failed with: {}. Transaction will rollback and be reattempted.", ex.getMessage());
retVal = null;
}
count++;
} while (retVal == null && count < TOTAL_TAG_READ_ATTEMPTS);
if (retVal == null) {
// if tag is still null,
// something bad must be happening
// - throw
String msg = throwables.stream().map(Throwable::getMessage).collect(Collectors.joining(", "));
throw new InternalErrorException(Msg.code(2023)
+ "Tag get/create failed after "
+ TOTAL_TAG_READ_ATTEMPTS
+ " attempts with error(s): "
+ msg);
}
return retVal;
}
private TypedQuery<TagDefinition> buildTagQuery(
TagTypeEnum theTagType, String theScheme, String theTerm, String theVersion, Boolean theUserSelected) {
CriteriaBuilder builder = myEntityManager.getCriteriaBuilder();
CriteriaQuery<TagDefinition> cq = builder.createQuery(TagDefinition.class);
Root<TagDefinition> from = cq.from(TagDefinition.class);
List<Predicate> predicates = new ArrayList<>();
predicates.add(builder.and(
builder.equal(from.get("myTagType"), theTagType), builder.equal(from.get("myCode"), theTerm)));
predicates.add(
isBlank(theScheme)
? builder.isNull(from.get("mySystem"))
: builder.equal(from.get("mySystem"), theScheme));
predicates.add(
isBlank(theVersion)
? builder.isNull(from.get("myVersion"))
: builder.equal(from.get("myVersion"), theVersion));
predicates.add(
isNull(theUserSelected)
? builder.isNull(from.get("myUserSelected"))
: builder.equal(from.get("myUserSelected"), theUserSelected));
cq.where(predicates.toArray(new Predicate[0]));
return myEntityManager.createQuery(cq);
}
void incrementId(T theResource, ResourceTable theSavedEntity, IIdType theResourceId) { void incrementId(T theResource, ResourceTable theSavedEntity, IIdType theResourceId) {
if (theResourceId == null || theResourceId.getVersionIdPart() == null) { if (theResourceId == null || theResourceId.getVersionIdPart() == null) {
theSavedEntity.initializeVersion(); theSavedEntity.initializeVersion();
@ -933,7 +807,7 @@ public abstract class BaseHapiFhirDao<T extends IBaseResource> extends BaseStora
@Override @Override
@CoverageIgnore @CoverageIgnore
public BaseHasResource readEntity(IIdType theValueId, RequestDetails theRequest) { public BaseHasResource readEntity(IIdType theValueId, RequestDetails theRequest) {
throw new NotImplementedException(Msg.code(927) + ""); throw new NotImplementedException(Msg.code(927));
} }
/** /**
@ -1840,7 +1714,7 @@ public abstract class BaseHapiFhirDao<T extends IBaseResource> extends BaseStora
@PostConstruct @PostConstruct
public void start() { public void start() {
// nothing yet this.tagDefinitionDao = new TagDefinitionDao(myEntityManagerFactory, myTransactionManager);
} }
@VisibleForTesting @VisibleForTesting

View File

@ -0,0 +1,249 @@
package ca.uhn.fhir.jpa.dao;
import ca.uhn.fhir.i18n.Msg;
import ca.uhn.fhir.jpa.model.entity.TagDefinition;
import ca.uhn.fhir.jpa.model.entity.TagTypeEnum;
import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
import ca.uhn.fhir.util.ThreadPoolUtil;
import jakarta.persistence.EntityManager;
import jakarta.persistence.EntityManagerFactory;
import jakarta.persistence.NoResultException;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.Predicate;
import jakarta.persistence.criteria.Root;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import static java.util.Objects.isNull;
import static org.apache.commons.lang3.StringUtils.isBlank;
/**
* This class is responsible for getting or creating tags.
* The class is designed to do the tag retrieval in a separate thread and transaction to avoid any form of
* suspension of any current transaction, an operation that is not widely supported in an XA environment.
*/
public class TagDefinitionDao {
// total attempts to do a tag transaction
private static final int TOTAL_TAG_READ_ATTEMPTS = 10;
private static final Logger ourLog = LoggerFactory.getLogger(TagDefinitionDao.class);
private final ThreadPoolTaskExecutor taskExecutor = ThreadPoolUtil.newThreadPool(1, 5, "getOrCreateTag-");
private final EntityManagerFactory myEntityManagerFactory;
private final PlatformTransactionManager myTransactionManager;
public TagDefinitionDao(
EntityManagerFactory theEntityManagerFactory, PlatformTransactionManager theTransactionManager) {
this.myEntityManagerFactory = theEntityManagerFactory;
this.myTransactionManager = theTransactionManager;
}
/**
* Gets the tag defined by the fed in values, or saves it if it does not
* exist.
* <p>
* Can also throw an InternalErrorException if something bad happens.
*/
TagDefinition getOrCreateTag(
TagTypeEnum theTagType,
String theScheme,
String theTerm,
String theLabel,
String theVersion,
Boolean theUserSelected) {
try {
// Execute tag retrieval in a separate thread to avoid transaction suspension issues,
// which can cause problems in an XA environment that does not support suspend/resuming transactions,
// such as when using the PostgreSQL JDBC driver.
Future<TagDefinition> future = taskExecutor.submit(new RetryableGetOrCreateTagUnitOfWork(
myEntityManagerFactory,
myTransactionManager,
theTagType,
theScheme,
theTerm,
theLabel,
theVersion,
theUserSelected));
return future.get();
} catch (InterruptedException | ExecutionException e) {
throw new InternalErrorException(
Msg.code(2547) + "Tag get/create thread failed to execute: " + e.getMessage());
}
}
static class RetryableGetOrCreateTagUnitOfWork implements Callable<TagDefinition> {
private final EntityManagerFactory myEntityManagerFactory;
private final PlatformTransactionManager myTransactionManager;
private final TagTypeEnum theTagType;
private final String theScheme;
private final String theTerm;
private final String theLabel;
private final String theVersion;
private final Boolean theUserSelected;
public RetryableGetOrCreateTagUnitOfWork(
EntityManagerFactory theEntityManagerFactory,
PlatformTransactionManager theTransactionManager,
TagTypeEnum theTagType,
String theScheme,
String theTerm,
String theLabel,
String theVersion,
Boolean theUserSelected) {
this.myEntityManagerFactory = theEntityManagerFactory;
this.myTransactionManager = theTransactionManager;
this.theTagType = theTagType;
this.theScheme = theScheme;
this.theTerm = theTerm;
this.theLabel = theLabel;
this.theVersion = theVersion;
this.theUserSelected = theUserSelected;
}
@Override
public TagDefinition call() throws Exception {
// Create a new entity manager for the current thread.
try (EntityManager myEntityManager = myEntityManagerFactory.createEntityManager()) {
TransactionTemplate template = new TransactionTemplate(myTransactionManager);
template.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
// this transaction will attempt to get or create the tag,
// repeating (on any failure) 10 times.
// if it fails more than this, we will throw exceptions
TagDefinition retVal;
int count = 0;
HashSet<Throwable> throwables = new HashSet<>();
do {
try {
retVal = template.execute(new TransactionCallback<>() {
// do the actual DB call(s) to read and/or write the values
private TagDefinition readOrCreate() {
TagDefinition val;
try {
// Join transaction if needed.
if (!myEntityManager.isJoinedToTransaction()) {
myEntityManager.joinTransaction();
}
TypedQuery<TagDefinition> q =
buildTagQuery(myEntityManager, theTagType, theScheme, theTerm, theVersion, theUserSelected);
q.setMaxResults(1);
val = q.getSingleResult();
} catch (NoResultException e) {
val = new TagDefinition(theTagType, theScheme, theTerm, theLabel);
val.setVersion(theVersion);
val.setUserSelected(theUserSelected);
myEntityManager.persist(val);
myEntityManager.flush();
}
return val;
}
@Override
public TagDefinition doInTransaction(TransactionStatus status) {
TagDefinition tag = null;
try {
tag = readOrCreate();
} catch (Exception ex) {
// log any exceptions - just in case
// they may be signs of things to come...
ourLog.warn(
"Tag read/write failed: "
+ ex.getMessage() + ". "
+ "This is not a failure on its own, "
+ "but could be useful information in the result of an actual failure.",
ex);
throwables.add(ex);
}
return tag;
}
});
} catch (Exception ex) {
// transaction template can fail if connections to db are exhausted and/or timeout
ourLog.warn(
"Transaction failed with: {}. Transaction will rollback and be reattempted.",
ex.getMessage());
retVal = null;
}
// Clear the persistence context to avoid stale data on retry
if (retVal == null) {
myEntityManager.clear();
}
count++;
} while (retVal == null && count < TOTAL_TAG_READ_ATTEMPTS);
if (retVal == null) {
// if tag is still null,
// something bad must be happening
// - throw
String msg = throwables.stream().map(Throwable::getMessage).collect(Collectors.joining(", "));
throw new InternalErrorException(Msg.code(2023)
+ "Tag get/create failed after "
+ TOTAL_TAG_READ_ATTEMPTS
+ " attempts with error(s): "
+ msg);
}
return retVal;
}
}
private TypedQuery<TagDefinition> buildTagQuery(
EntityManager theEntityManager,
TagTypeEnum theTagType,
String theScheme,
String theTerm,
String theVersion,
Boolean theUserSelected) {
CriteriaBuilder builder = theEntityManager.getCriteriaBuilder();
CriteriaQuery<TagDefinition> cq = builder.createQuery(TagDefinition.class);
Root<TagDefinition> from = cq.from(TagDefinition.class);
List<Predicate> predicates = new ArrayList<>();
predicates.add(builder.and(
builder.equal(from.get("myTagType"), theTagType), builder.equal(from.get("myCode"), theTerm)));
predicates.add(
isBlank(theScheme)
? builder.isNull(from.get("mySystem"))
: builder.equal(from.get("mySystem"), theScheme));
predicates.add(
isBlank(theVersion)
? builder.isNull(from.get("myVersion"))
: builder.equal(from.get("myVersion"), theVersion));
predicates.add(
isNull(theUserSelected)
? builder.isNull(from.get("myUserSelected"))
: builder.equal(from.get("myUserSelected"), theUserSelected));
cq.where(predicates.toArray(new Predicate[0]));
TypedQuery<TagDefinition> query = theEntityManager.createQuery(cq);
query.setMaxResults(1);
return query;
}
}
}