Add tests for transaction sorting

This commit is contained in:
Tadgh 2021-04-20 14:12:25 -04:00
parent 1670ed7202
commit ff2690b74e
3 changed files with 200 additions and 16 deletions

View File

@ -1,6 +1,10 @@
package ca.uhn.fhir.util; package ca.uhn.fhir.util;
import ca.uhn.fhir.context.*; import ca.uhn.fhir.context.BaseRuntimeChildDefinition;
import ca.uhn.fhir.context.BaseRuntimeElementCompositeDefinition;
import ca.uhn.fhir.context.BaseRuntimeElementDefinition;
import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.context.RuntimeResourceDefinition;
import ca.uhn.fhir.rest.api.PatchTypeEnum; import ca.uhn.fhir.rest.api.PatchTypeEnum;
import ca.uhn.fhir.rest.api.RequestTypeEnum; import ca.uhn.fhir.rest.api.RequestTypeEnum;
import ca.uhn.fhir.rest.server.exceptions.InvalidRequestException; import ca.uhn.fhir.rest.server.exceptions.InvalidRequestException;
@ -9,17 +13,21 @@ import ca.uhn.fhir.util.bundle.BundleEntryParts;
import ca.uhn.fhir.util.bundle.EntryListAccumulator; import ca.uhn.fhir.util.bundle.EntryListAccumulator;
import ca.uhn.fhir.util.bundle.ModifiableBundleEntry; import ca.uhn.fhir.util.bundle.ModifiableBundleEntry;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.hl7.fhir.instance.model.api.*; import org.hl7.fhir.instance.model.api.IBase;
import org.hl7.fhir.instance.model.api.IBaseBinary;
import org.hl7.fhir.instance.model.api.IBaseBundle;
import org.hl7.fhir.instance.model.api.IBaseResource;
import org.hl7.fhir.instance.model.api.IPrimitiveType;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.function.Consumer; import java.util.function.Consumer;
import static org.apache.commons.lang3.StringUtils.isNotBlank; import static org.apache.commons.lang3.StringUtils.isNotBlank;
/* /*
* #%L * #%L
* HAPI FHIR - Core Library * HAPI FHIR - Core Library
@ -44,6 +52,8 @@ import static org.apache.commons.lang3.StringUtils.isNotBlank;
* Fetch resources from a bundle * Fetch resources from a bundle
*/ */
public class BundleUtil { public class BundleUtil {
private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(BundleUtil.class);
/** /**
* @return Returns <code>null</code> if the link isn't found or has no value * @return Returns <code>null</code> if the link isn't found or has no value
@ -176,18 +186,31 @@ public class BundleUtil {
static int GRAY = 2; static int GRAY = 2;
static int BLACK = 3; static int BLACK = 3;
public static IBaseBundle topologicalSort(FhirContext theContext, IBaseBundle theBundle) { public static List<BundleEntryParts> topologicalSort(FhirContext theContext, IBaseBundle theBundle, RequestTypeEnum theRequestTypeEnum) {
boolean isPossible = true; SortLegality legality = new SortLegality();
HashMap<String, Integer> color = new HashMap<String, Integer>(); HashMap<String, Integer> color = new HashMap<String, Integer>();
HashMap<String, List<String>> adjList = new HashMap<>(); HashMap<String, List<String>> adjList = new HashMap<>();
List<String> topologicalOrder = new ArrayList<>(); List<String> topologicalOrder = new ArrayList<>();
List<List<String>> prerequisites = new ArrayList<>();
List<BundleEntryParts> bundleEntryParts = toListOfEntries(theContext, theBundle); List<BundleEntryParts> bundleEntryParts = toListOfEntries(theContext, theBundle);
bundleEntryParts.removeIf(bep -> !bep.getRequestType().equals(theRequestTypeEnum));
HashMap<String, BundleEntryParts> resourceIdToBundleEntryMap = new HashMap<>();
for (BundleEntryParts bundleEntryPart : bundleEntryParts) { for (BundleEntryParts bundleEntryPart : bundleEntryParts) {
IBaseResource resource = bundleEntryPart.getResource(); IBaseResource resource = bundleEntryPart.getResource();
String resourceId = resource.getIdElement().toString(); String resourceId = resource.getIdElement().toString();
resourceIdToBundleEntryMap.put(resourceId, bundleEntryPart);
if (resourceId == null) {
if (bundleEntryPart.getFullUrl() != null) {
resourceId = bundleEntryPart.getFullUrl();
}
}
color.put(resourceId, WHITE);
}
for (BundleEntryParts bundleEntryPart : bundleEntryParts) {
IBaseResource resource = bundleEntryPart.getResource();
String resourceId = resource.getIdElement().toString();
resourceIdToBundleEntryMap.put(resourceId, bundleEntryPart);
if (resourceId == null) { if (resourceId == null) {
if (bundleEntryPart.getFullUrl() != null) { if (bundleEntryPart.getFullUrl() != null) {
resourceId = bundleEntryPart.getFullUrl(); resourceId = bundleEntryPart.getFullUrl();
@ -197,14 +220,80 @@ public class BundleUtil {
String finalResourceId = resourceId; String finalResourceId = resourceId;
allResourceReferences allResourceReferences
.forEach(refInfo -> { .forEach(refInfo -> {
prerequisites.add(Arrays.asList(finalResourceId, refInfo.getResourceReference().getReferenceElement().getValue())); String referencedResourceId = refInfo.getResourceReference().getReferenceElement().getValue();
}); if (color.containsKey(referencedResourceId)) {
if (!adjList.containsKey(finalResourceId)) {
adjList.put(finalResourceId, new ArrayList<>());
} }
System.out.println("zoop!!"); adjList.get(finalResourceId).add(refInfo.getResourceReference().getReferenceElement().getValue());
}
});
}
//All nodes are now white
//Adjacency List has been built.
for (Map.Entry<String, Integer> entry:color.entrySet()) {
if (entry.getValue() == WHITE) {
depthFirstSearch(entry.getKey(), color, adjList, topologicalOrder, legality);
}
}
if (legality.isLegal()) {
if (ourLog.isDebugEnabled()) {
ourLog.debug("Topological order is: {}", String.join(",", topologicalOrder));
}
List<BundleEntryParts> beps = new ArrayList<>();
for (int i = 0;i < topologicalOrder.size(); i++) {
BundleEntryParts bep = resourceIdToBundleEntryMap.get(topologicalOrder.get(i));
beps.add(bep);
}
//In case of delete, we want to delete child elements LAST.
if (theRequestTypeEnum.equals(RequestTypeEnum.DELETE)) {
Collections.reverse(beps);
}
return beps;
} else {
return null; return null;
} }
}
private static class SortLegality {
private boolean myIsLegal;
SortLegality() {
this.myIsLegal = true;
}
private void setLegal(boolean theLegal) {
myIsLegal = theLegal;
}
public boolean isLegal() {
return myIsLegal;
}
}
private static void depthFirstSearch(String theResourceId, HashMap<String, Integer> theResourceIdToColor, HashMap<String, List<String>> theAdjList, List<String> theTopologicalOrder, SortLegality theLegality) {
System.out.println("RECURSING ON " + theResourceId);
if (!theLegality.isLegal()) {
System.out.println("IMPOSSIBLE!");
return;
}
//We are currently recursing over this node (gray)
theResourceIdToColor.put(theResourceId, GRAY);
for (String neighbourResourceId: theAdjList.getOrDefault(theResourceId, new ArrayList<>())) {
if (theResourceIdToColor.get(neighbourResourceId) == WHITE) {
depthFirstSearch(neighbourResourceId, theResourceIdToColor, theAdjList, theTopologicalOrder, theLegality);
} else if (theResourceIdToColor.get(neighbourResourceId) == GRAY) {
theLegality.setLegal(false);
return;
}
}
//Mark the node as black
theResourceIdToColor.put(theResourceId, BLACK);
theTopologicalOrder.add(theResourceId);
}
public static void processEntries(FhirContext theContext, IBaseBundle theBundle, Consumer<ModifiableBundleEntry> theProcessor) { public static void processEntries(FhirContext theContext, IBaseBundle theBundle, Consumer<ModifiableBundleEntry> theProcessor) {
RuntimeResourceDefinition bundleDef = theContext.getResourceDefinition(theBundle); RuntimeResourceDefinition bundleDef = theContext.getResourceDefinition(theBundle);

View File

@ -58,5 +58,4 @@ public class BundleEntryParts {
public String getUrl() { public String getUrl() {
return myUrl; return myUrl;
} }
} }

View File

@ -4,10 +4,12 @@ import ca.uhn.fhir.interceptor.api.HookParams;
import ca.uhn.fhir.interceptor.api.IInterceptorService; import ca.uhn.fhir.interceptor.api.IInterceptorService;
import ca.uhn.fhir.interceptor.api.Pointcut; import ca.uhn.fhir.interceptor.api.Pointcut;
import ca.uhn.fhir.jpa.api.model.DaoMethodOutcome; import ca.uhn.fhir.jpa.api.model.DaoMethodOutcome;
import ca.uhn.fhir.rest.api.RequestTypeEnum;
import ca.uhn.fhir.rest.api.server.storage.DeferredInterceptorBroadcasts; import ca.uhn.fhir.rest.api.server.storage.DeferredInterceptorBroadcasts;
import ca.uhn.fhir.rest.server.exceptions.ResourceGoneException; import ca.uhn.fhir.rest.server.exceptions.ResourceGoneException;
import ca.uhn.fhir.rest.server.exceptions.ResourceVersionConflictException; import ca.uhn.fhir.rest.server.exceptions.ResourceVersionConflictException;
import ca.uhn.fhir.util.BundleUtil; import ca.uhn.fhir.util.BundleUtil;
import ca.uhn.fhir.util.bundle.BundleEntryParts;
import ca.uhn.test.concurrency.PointcutLatch; import ca.uhn.test.concurrency.PointcutLatch;
import com.google.common.collect.ListMultimap; import com.google.common.collect.ListMultimap;
import org.hl7.fhir.instance.model.api.IBaseResource; import org.hl7.fhir.instance.model.api.IBaseResource;
@ -24,14 +26,17 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import java.util.Collections;
import java.util.List; import java.util.List;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.matchesPattern; import static org.hamcrest.Matchers.matchesPattern;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
import static org.hamcrest.Matchers.is;
public class TransactionHookTest extends BaseJpaR4SystemTest { public class TransactionHookTest extends BaseJpaR4SystemTest {
@ -54,7 +59,7 @@ public class TransactionHookTest extends BaseJpaR4SystemTest {
} }
@Test @Test
public void testTopologicalTransactionSorting() { public void testTopologicalTransactionSortForCreates() {
Bundle b = new Bundle(); Bundle b = new Bundle();
Bundle.BundleEntryComponent bundleEntryComponent = b.addEntry(); Bundle.BundleEntryComponent bundleEntryComponent = b.addEntry();
@ -87,8 +92,99 @@ public class TransactionHookTest extends BaseJpaR4SystemTest {
organizationComponent.setResource(org1); organizationComponent.setResource(org1);
organizationComponent.getRequest().setMethod(Bundle.HTTPVerb.POST).setUrl("Patient"); organizationComponent.getRequest().setMethod(Bundle.HTTPVerb.POST).setUrl("Patient");
BundleUtil.topologicalSort(myFhirCtx, b); List<BundleEntryParts> postBundleEntries = BundleUtil.topologicalSort(myFhirCtx, b, RequestTypeEnum.POST);
assertThat(postBundleEntries, hasSize(4));
int observationIndex = getIndexOfEntryWithId("Observation/O1", postBundleEntries);
int patientIndex = getIndexOfEntryWithId("Patient/P1", postBundleEntries);
int organizationIndex = getIndexOfEntryWithId("Organization/Org1", postBundleEntries);
assertTrue(organizationIndex < patientIndex);
assertTrue(patientIndex < observationIndex);
}
@Test
public void testTransactionSorterFailsOnCyclicReference() {
Bundle b = new Bundle();
Bundle.BundleEntryComponent bundleEntryComponent = b.addEntry();
final Observation obs1 = new Observation();
obs1.setStatus(Observation.ObservationStatus.FINAL);
obs1.setSubject(new Reference("Patient/P1"));
obs1.setValue(new Quantity(4));
obs1.setId("Observation/O1");
obs1.setHasMember(Collections.singletonList(new Reference("Observation/O2")));
bundleEntryComponent.setResource(obs1);
bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.POST).setUrl("Observation");
bundleEntryComponent = b.addEntry();
final Observation obs2 = new Observation();
obs2.setStatus(Observation.ObservationStatus.FINAL);
obs2.setValue(new Quantity(4));
obs2.setId("Observation/O2");
obs2.setHasMember(Collections.singletonList(new Reference("Observation/O1")));
bundleEntryComponent.setResource(obs2);
bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.POST).setUrl("Observation");
List<BundleEntryParts> postBundleEntries = BundleUtil.topologicalSort(myFhirCtx, b, RequestTypeEnum.POST);
//Null value indicates that we hit a cycle, and could not process the deletions in order
assertThat(postBundleEntries, is(nullValue()));
}
@Test
public void testTransactionSorterReturnsDeletesInCorrectProcessingOrder() {
Bundle b = new Bundle();
Bundle.BundleEntryComponent bundleEntryComponent = b.addEntry();
final Observation obs1 = new Observation();
obs1.setStatus(Observation.ObservationStatus.FINAL);
obs1.setSubject(new Reference("Patient/P1"));
obs1.setValue(new Quantity(4));
obs1.setId("Observation/O1");
bundleEntryComponent.setResource(obs1);
bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Observation");
bundleEntryComponent = b.addEntry();
final Observation obs2 = new Observation();
obs2.setStatus(Observation.ObservationStatus.FINAL);
obs2.setValue(new Quantity(4));
obs2.setId("Observation/O2");
bundleEntryComponent.setResource(obs2);
bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Observation");
Bundle.BundleEntryComponent patientComponent = b.addEntry();
Patient pat1 = new Patient();
pat1.setId("Patient/P1");
pat1.setManagingOrganization(new Reference("Organization/Org1"));
patientComponent.setResource(pat1);
patientComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Patient");
Bundle.BundleEntryComponent organizationComponent = b.addEntry();
Organization org1 = new Organization();
org1.setId("Organization/Org1");
organizationComponent.setResource(org1);
organizationComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Organization");
List<BundleEntryParts> postBundleEntries = BundleUtil.topologicalSort(myFhirCtx, b, RequestTypeEnum.DELETE);
assertThat(postBundleEntries, hasSize(4));
int observationIndex = getIndexOfEntryWithId("Observation/O1", postBundleEntries);
int patientIndex = getIndexOfEntryWithId("Patient/P1", postBundleEntries);
int organizationIndex = getIndexOfEntryWithId("Organization/Org1", postBundleEntries);
assertTrue(patientIndex < organizationIndex);
assertTrue(observationIndex < patientIndex);
}
private int getIndexOfEntryWithId(String theResourceId, List<BundleEntryParts> theBundleEntryParts) {
for (int i = 0; i < theBundleEntryParts.size(); i++) {
String id = theBundleEntryParts.get(i).getResource().getIdElement().toUnqualifiedVersionless().toString();
if (id.equals(theResourceId)) {
return i;
}
}
fail("Didn't find resource with ID " + theResourceId);
return -1;
} }
@Test @Test