Avoid query cache poisoning (#2647)

* Avoid query cache poisoning

* Test fixes

* Add changelog
This commit is contained in:
James Agnew 2021-05-11 08:39:38 -04:00 committed by GitHub
parent 43631d4937
commit 3015438a0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 131 additions and 13 deletions

View File

@ -0,0 +1,5 @@
---
type: fix
issue: 2647
title: "The new Match URL cache suffered from potential cache poisoning if multiple threads performed
a condiitonal create operation at the same time."

View File

@ -179,6 +179,7 @@ public abstract class BaseHapiFhirResourceDao<T extends IBaseResource> extends B
private IRequestPartitionHelperSvc myPartitionHelperSvc;
@Autowired
private MemoryCacheService myMemoryCacheService;
private TransactionTemplate myTxTemplate;
@Override
public DaoMethodOutcome create(final T theResource) {
@ -273,14 +274,22 @@ public abstract class BaseHapiFhirResourceDao<T extends IBaseResource> extends B
};
Supplier<IIdType> idSupplier = () -> {
IIdType retVal = myIdHelperService.translatePidIdToForcedId(myFhirContext, myResourceName, pid);
if (!retVal.hasVersionIdPart()) {
return myMemoryCacheService.get(MemoryCacheService.CacheEnum.RESOURCE_CONDITIONAL_CREATE_VERSION, retVal, t -> {
long version = myResourceTableDao.findCurrentVersionByPid(pid.getIdAsLong());
return myFhirContext.getVersion().newIdType().setParts(retVal.getBaseUrl(), retVal.getResourceType(), retVal.getIdPart(), Long.toString(version));
});
}
return retVal;
return myTxTemplate.execute(tx-> {
IIdType retVal = myIdHelperService.translatePidIdToForcedId(myFhirContext, myResourceName, pid);
if (!retVal.hasVersionIdPart()) {
IIdType idWithVersion = myMemoryCacheService.getIfPresent(MemoryCacheService.CacheEnum.RESOURCE_CONDITIONAL_CREATE_VERSION, pid.getIdAsLong());
if (idWithVersion == null) {
Long version = myResourceTableDao.findCurrentVersionByPid(pid.getIdAsLong());
if (version != null) {
retVal = myFhirContext.getVersion().newIdType().setParts(retVal.getBaseUrl(), retVal.getResourceType(), retVal.getIdPart(), Long.toString(version));
myMemoryCacheService.putAfterCommit(MemoryCacheService.CacheEnum.RESOURCE_CONDITIONAL_CREATE_VERSION, pid.getIdAsLong(), retVal);
}
} else {
retVal = idWithVersion;
}
}
return retVal;
});
};
return toMethodOutcomeLazy(theRequest, pid, entitySupplier, idSupplier).setCreated(false).setNop(true);
@ -1057,6 +1066,7 @@ public abstract class BaseHapiFhirResourceDao<T extends IBaseResource> extends B
public void start() {
ourLog.debug("Starting resource DAO for type: {}", getResourceName());
myInstanceValidator = getApplicationContext().getBean(IInstanceValidatorModule.class);
myTxTemplate = new TransactionTemplate(myPlatformTransactionManager);
super.start();
}

View File

@ -43,6 +43,8 @@ import org.apache.commons.lang3.Validate;
import org.hl7.fhir.instance.model.api.IBaseResource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import java.util.Collections;
import java.util.Set;
@ -80,7 +82,8 @@ public class MatchResourceUrlService {
Set<ResourcePersistentId> retVal = search(paramMap, theResourceType, theRequest);
if (myDaoConfig.getMatchUrlCache() && retVal.size() == 1) {
myMemoryCacheService.put(MemoryCacheService.CacheEnum.MATCH_URL, theMatchUrl, retVal.iterator().next());
ResourcePersistentId pid = retVal.iterator().next();
myMemoryCacheService.putAfterCommit(MemoryCacheService.CacheEnum.MATCH_URL, theMatchUrl, pid);
}
return retVal;
@ -113,7 +116,7 @@ public class MatchResourceUrlService {
Validate.notBlank(theMatchUrl);
Validate.notNull(theResourcePersistentId);
if (myDaoConfig.getMatchUrlCache()) {
myMemoryCacheService.put(MemoryCacheService.CacheEnum.MATCH_URL, theMatchUrl, theResourcePersistentId);
myMemoryCacheService.putAfterCommit(MemoryCacheService.CacheEnum.MATCH_URL, theMatchUrl, theResourcePersistentId);
}
}
}

View File

@ -93,5 +93,5 @@ public interface IResourceTableDao extends JpaRepository<ResourceTable, Long> {
Collection<Object[]> findLookupFieldsByResourcePidInPartitionNull(@Param("pid") List<Long> thePids);
@Query("SELECT t.myVersion FROM ResourceTable t WHERE t.myId = :pid")
long findCurrentVersionByPid(@Param("pid") Long thePid);
Long findCurrentVersionByPid(@Param("pid") Long thePid);
}

View File

@ -23,12 +23,15 @@ package ca.uhn.fhir.jpa.util;
import ca.uhn.fhir.jpa.api.config.DaoConfig;
import ca.uhn.fhir.jpa.api.model.TranslationQuery;
import ca.uhn.fhir.jpa.model.entity.TagTypeEnum;
import ca.uhn.fhir.rest.api.server.storage.ResourcePersistentId;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.hl7.fhir.instance.model.api.IIdType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import javax.annotation.Nonnull;
import javax.annotation.PostConstruct;
@ -100,6 +103,19 @@ public class MemoryCacheService {
getCache(theCache).put(theKey, theValue);
}
public <K, V> void putAfterCommit(CacheEnum theCache, K theKey, V theValue) {
if (TransactionSynchronizationManager.isSynchronizationActive()) {
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
@Override
public void afterCommit() {
put(theCache, theKey, theValue);
}
});
} else {
put(theCache, theKey, theValue);
}
}
public <K, V> Map<K, V> getAllPresent(CacheEnum theCache, Iterable<K> theKeys) {
return (Map<K, V>) getCache(theCache).getAllPresent(theKeys);
}
@ -122,7 +138,7 @@ public class MemoryCacheService {
CONCEPT_TRANSLATION(TranslationQuery.class),
MATCH_URL(String.class),
CONCEPT_TRANSLATION_REVERSE(TranslationQuery.class),
RESOURCE_CONDITIONAL_CREATE_VERSION(IIdType.class),
RESOURCE_CONDITIONAL_CREATE_VERSION(Long.class),
HISTORY_COUNT(HistoryCountKey.class);
private final Class<?> myKeyType;

View File

@ -11,15 +11,20 @@ import ca.uhn.fhir.rest.server.RestfulServer;
import ca.uhn.fhir.rest.server.exceptions.PreconditionFailedException;
import ca.uhn.fhir.rest.server.exceptions.ResourceVersionConflictException;
import ca.uhn.fhir.rest.server.servlet.ServletRequestDetails;
import ca.uhn.fhir.util.BundleBuilder;
import ca.uhn.fhir.util.HapiExtensions;
import org.hl7.fhir.instance.model.api.IIdType;
import org.hl7.fhir.r4.model.BooleanType;
import org.hl7.fhir.r4.model.Bundle;
import org.hl7.fhir.r4.model.CodeType;
import org.hl7.fhir.r4.model.Coverage;
import org.hl7.fhir.r4.model.Enumerations;
import org.hl7.fhir.r4.model.ExplanationOfBenefit;
import org.hl7.fhir.r4.model.IdType;
import org.hl7.fhir.r4.model.Observation;
import org.hl7.fhir.r4.model.Parameters;
import org.hl7.fhir.r4.model.Patient;
import org.hl7.fhir.r4.model.Practitioner;
import org.hl7.fhir.r4.model.SearchParameter;
import org.hl7.fhir.r4.model.StringType;
import org.junit.jupiter.api.AfterEach;
@ -35,6 +40,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import static org.hamcrest.MatcherAssert.assertThat;
@ -70,6 +76,83 @@ public class FhirResourceDaoR4ConcurrentWriteTest extends BaseJpaR4Test {
myInterceptorRegistry.unregisterInterceptor(myRetryInterceptor);
}
@Test
public void testConcurrentTransactionCreates() throws ExecutionException, InterruptedException {
myDaoConfig.setMatchUrlCache(true);
AtomicInteger counter = new AtomicInteger(0);
Runnable creator = () -> {
BundleBuilder bb = new BundleBuilder(myFhirCtx);
String patientId = "Patient/PT" + counter.get();
IdType practitionerId = IdType.newRandomUuid();
IdType practitionerId2 = IdType.newRandomUuid();
ExplanationOfBenefit eob = new ExplanationOfBenefit();
eob.addIdentifier().setSystem("foo").setValue("" + counter.get());
eob.getPatient().setReference(patientId);
eob.addCareTeam().getProvider().setReference(practitionerId.getValue());
eob.addCareTeam().getProvider().setReference(practitionerId2.getValue());
bb.addTransactionUpdateEntry(eob).conditional("ExplanationOfBenefit?identifier=foo|" + counter.get());
Patient pt = new Patient();
pt.setId(patientId);
pt.setActive(true);
bb.addTransactionUpdateEntry(pt);
Coverage coverage = new Coverage();
coverage.addIdentifier().setSystem("foo").setValue("" + counter.get());
coverage.getBeneficiary().setReference(patientId);
bb.addTransactionUpdateEntry(coverage).conditional("Coverage?identifier=foo|" + counter.get());
Practitioner practitioner = new Practitioner();
practitioner.setId(practitionerId);
practitioner.addIdentifier().setSystem("foo").setValue("" + counter.get());
bb.addTransactionCreateEntry(practitioner).conditional("Practitioner?identifier=foo|" + counter.get());
Practitioner practitioner2 = new Practitioner();
practitioner2.setId(practitionerId2);
practitioner2.addIdentifier().setSystem("foo2").setValue("" + counter.get());
bb.addTransactionCreateEntry(practitioner2).conditional("Practitioner?identifier=foo2|" + counter.get());
Observation obs = new Observation();
obs.setId("Observation/OBS" + counter);
obs.getSubject().setReference(pt.getId());
bb.addTransactionUpdateEntry(obs);
Bundle input = (Bundle) bb.getBundle();
mySystemDao.transaction(new SystemRequestDetails(), input);
};
for (int i = 0; i < 10; i++) {
counter.set(i);
ourLog.info("*********************************************************************************");
ourLog.info("Starting pass {}", i);
ourLog.info("*********************************************************************************");
List<Future<?>> futures = new ArrayList<>();
for (int j = 0; j < 10; j++) {
futures.add(myExecutor.submit(creator));
}
for (Future<?> next : futures) {
try {
next.get();
} catch (Exception e) {
// ignore
}
}
creator.run();
}
runInTransaction(()->{
assertEquals(60, myResourceTableDao.count());
});
}
@Test
public void testCreateWithClientAssignedId() {
myInterceptorRegistry.registerInterceptor(myRetryInterceptor);

View File

@ -698,7 +698,7 @@ public class FhirResourceDaoR4QueryCountTest extends BaseJpaR4Test {
myCaptureQueriesListener.clear();
mySystemDao.transaction(mySrd, bundleCreator.get());
assertEquals(1, myCaptureQueriesListener.countSelectQueries());
assertEquals(2, myCaptureQueriesListener.countSelectQueries());
assertEquals(5, myCaptureQueriesListener.countInsertQueries());
assertEquals(0, myCaptureQueriesListener.countDeleteQueries());

View File

@ -85,6 +85,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import static org.hamcrest.MatcherAssert.assertThat;