diff --git a/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/BundleUtil.java b/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/BundleUtil.java index 1d48b6a8fa9..471c317b8da 100644 --- a/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/BundleUtil.java +++ b/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/BundleUtil.java @@ -1,6 +1,10 @@ package ca.uhn.fhir.util; -import ca.uhn.fhir.context.*; +import ca.uhn.fhir.context.BaseRuntimeChildDefinition; +import ca.uhn.fhir.context.BaseRuntimeElementCompositeDefinition; +import ca.uhn.fhir.context.BaseRuntimeElementDefinition; +import ca.uhn.fhir.context.FhirContext; +import ca.uhn.fhir.context.RuntimeResourceDefinition; import ca.uhn.fhir.rest.api.PatchTypeEnum; import ca.uhn.fhir.rest.api.RequestTypeEnum; import ca.uhn.fhir.rest.server.exceptions.InvalidRequestException; @@ -9,17 +13,21 @@ import ca.uhn.fhir.util.bundle.BundleEntryParts; import ca.uhn.fhir.util.bundle.EntryListAccumulator; import ca.uhn.fhir.util.bundle.ModifiableBundleEntry; import org.apache.commons.lang3.tuple.Pair; -import org.hl7.fhir.instance.model.api.*; +import org.hl7.fhir.instance.model.api.IBase; +import org.hl7.fhir.instance.model.api.IBaseBinary; +import org.hl7.fhir.instance.model.api.IBaseBundle; +import org.hl7.fhir.instance.model.api.IBaseResource; +import org.hl7.fhir.instance.model.api.IPrimitiveType; import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.function.Consumer; import static org.apache.commons.lang3.StringUtils.isNotBlank; - /* * #%L * HAPI FHIR - Core Library @@ -44,6 +52,8 @@ import static org.apache.commons.lang3.StringUtils.isNotBlank; * Fetch resources from a bundle */ public class BundleUtil { + private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(BundleUtil.class); + /** * @return Returns null if the link isn't found or has no value @@ -176,18 +186,31 @@ public class BundleUtil { static int GRAY = 2; static int BLACK = 3; - public static IBaseBundle topologicalSort(FhirContext theContext, IBaseBundle theBundle) { - boolean isPossible = true; + public static List topologicalSort(FhirContext theContext, IBaseBundle theBundle, RequestTypeEnum theRequestTypeEnum) { + SortLegality legality = new SortLegality(); HashMap color = new HashMap(); HashMap> adjList = new HashMap<>(); List topologicalOrder = new ArrayList<>(); - - List> prerequisites = new ArrayList<>(); - List bundleEntryParts = toListOfEntries(theContext, theBundle); + bundleEntryParts.removeIf(bep -> !bep.getRequestType().equals(theRequestTypeEnum)); + HashMap resourceIdToBundleEntryMap = new HashMap<>(); + for (BundleEntryParts bundleEntryPart : bundleEntryParts) { IBaseResource resource = bundleEntryPart.getResource(); String resourceId = resource.getIdElement().toString(); + resourceIdToBundleEntryMap.put(resourceId, bundleEntryPart); + if (resourceId == null) { + if (bundleEntryPart.getFullUrl() != null) { + resourceId = bundleEntryPart.getFullUrl(); + } + } + color.put(resourceId, WHITE); + } + + for (BundleEntryParts bundleEntryPart : bundleEntryParts) { + IBaseResource resource = bundleEntryPart.getResource(); + String resourceId = resource.getIdElement().toString(); + resourceIdToBundleEntryMap.put(resourceId, bundleEntryPart); if (resourceId == null) { if (bundleEntryPart.getFullUrl() != null) { resourceId = bundleEntryPart.getFullUrl(); @@ -197,14 +220,80 @@ public class BundleUtil { String finalResourceId = resourceId; allResourceReferences .forEach(refInfo -> { - prerequisites.add(Arrays.asList(finalResourceId, refInfo.getResourceReference().getReferenceElement().getValue())); + String referencedResourceId = refInfo.getResourceReference().getReferenceElement().getValue(); + if (color.containsKey(referencedResourceId)) { + if (!adjList.containsKey(finalResourceId)) { + adjList.put(finalResourceId, new ArrayList<>()); + } + adjList.get(finalResourceId).add(refInfo.getResourceReference().getReferenceElement().getValue()); + } }); - } - System.out.println("zoop!!"); - return null; + //All nodes are now white + //Adjacency List has been built. + + for (Map.Entry entry:color.entrySet()) { + if (entry.getValue() == WHITE) { + depthFirstSearch(entry.getKey(), color, adjList, topologicalOrder, legality); + } + } + if (legality.isLegal()) { + if (ourLog.isDebugEnabled()) { + ourLog.debug("Topological order is: {}", String.join(",", topologicalOrder)); + } + List beps = new ArrayList<>(); + + for (int i = 0;i < topologicalOrder.size(); i++) { + BundleEntryParts bep = resourceIdToBundleEntryMap.get(topologicalOrder.get(i)); + beps.add(bep); + } + + //In case of delete, we want to delete child elements LAST. + if (theRequestTypeEnum.equals(RequestTypeEnum.DELETE)) { + Collections.reverse(beps); + } + return beps; + } else { + return null; + } } + private static class SortLegality { + private boolean myIsLegal; + + SortLegality() { + this.myIsLegal = true; + } + private void setLegal(boolean theLegal) { + myIsLegal = theLegal; + } + + public boolean isLegal() { + return myIsLegal; + } + } + private static void depthFirstSearch(String theResourceId, HashMap theResourceIdToColor, HashMap> theAdjList, List theTopologicalOrder, SortLegality theLegality) { + System.out.println("RECURSING ON " + theResourceId); + if (!theLegality.isLegal()) { + System.out.println("IMPOSSIBLE!"); + return; + } + + //We are currently recursing over this node (gray) + theResourceIdToColor.put(theResourceId, GRAY); + + for (String neighbourResourceId: theAdjList.getOrDefault(theResourceId, new ArrayList<>())) { + if (theResourceIdToColor.get(neighbourResourceId) == WHITE) { + depthFirstSearch(neighbourResourceId, theResourceIdToColor, theAdjList, theTopologicalOrder, theLegality); + } else if (theResourceIdToColor.get(neighbourResourceId) == GRAY) { + theLegality.setLegal(false); + return; + } + } + //Mark the node as black + theResourceIdToColor.put(theResourceId, BLACK); + theTopologicalOrder.add(theResourceId); + } public static void processEntries(FhirContext theContext, IBaseBundle theBundle, Consumer theProcessor) { RuntimeResourceDefinition bundleDef = theContext.getResourceDefinition(theBundle); diff --git a/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/bundle/BundleEntryParts.java b/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/bundle/BundleEntryParts.java index 5d698cec7b3..7bc9ecfa456 100644 --- a/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/bundle/BundleEntryParts.java +++ b/hapi-fhir-base/src/main/java/ca/uhn/fhir/util/bundle/BundleEntryParts.java @@ -58,5 +58,4 @@ public class BundleEntryParts { public String getUrl() { return myUrl; } - } diff --git a/hapi-fhir-jpaserver-base/src/test/java/ca/uhn/fhir/jpa/dao/r4/TransactionHookTest.java b/hapi-fhir-jpaserver-base/src/test/java/ca/uhn/fhir/jpa/dao/r4/TransactionHookTest.java index 57e4ae32ec6..0d046d782cb 100644 --- a/hapi-fhir-jpaserver-base/src/test/java/ca/uhn/fhir/jpa/dao/r4/TransactionHookTest.java +++ b/hapi-fhir-jpaserver-base/src/test/java/ca/uhn/fhir/jpa/dao/r4/TransactionHookTest.java @@ -4,10 +4,12 @@ import ca.uhn.fhir.interceptor.api.HookParams; import ca.uhn.fhir.interceptor.api.IInterceptorService; import ca.uhn.fhir.interceptor.api.Pointcut; import ca.uhn.fhir.jpa.api.model.DaoMethodOutcome; +import ca.uhn.fhir.rest.api.RequestTypeEnum; import ca.uhn.fhir.rest.api.server.storage.DeferredInterceptorBroadcasts; import ca.uhn.fhir.rest.server.exceptions.ResourceGoneException; import ca.uhn.fhir.rest.server.exceptions.ResourceVersionConflictException; import ca.uhn.fhir.util.BundleUtil; +import ca.uhn.fhir.util.bundle.BundleEntryParts; import ca.uhn.test.concurrency.PointcutLatch; import com.google.common.collect.ListMultimap; import org.hl7.fhir.instance.model.api.IBaseResource; @@ -24,14 +26,17 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; +import java.util.Collections; import java.util.List; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.matchesPattern; +import static org.hamcrest.Matchers.nullValue; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.hamcrest.Matchers.is; public class TransactionHookTest extends BaseJpaR4SystemTest { @@ -54,7 +59,7 @@ public class TransactionHookTest extends BaseJpaR4SystemTest { } @Test - public void testTopologicalTransactionSorting() { + public void testTopologicalTransactionSortForCreates() { Bundle b = new Bundle(); Bundle.BundleEntryComponent bundleEntryComponent = b.addEntry(); @@ -87,8 +92,99 @@ public class TransactionHookTest extends BaseJpaR4SystemTest { organizationComponent.setResource(org1); organizationComponent.getRequest().setMethod(Bundle.HTTPVerb.POST).setUrl("Patient"); - BundleUtil.topologicalSort(myFhirCtx, b); + List postBundleEntries = BundleUtil.topologicalSort(myFhirCtx, b, RequestTypeEnum.POST); + assertThat(postBundleEntries, hasSize(4)); + + int observationIndex = getIndexOfEntryWithId("Observation/O1", postBundleEntries); + int patientIndex = getIndexOfEntryWithId("Patient/P1", postBundleEntries); + int organizationIndex = getIndexOfEntryWithId("Organization/Org1", postBundleEntries); + + assertTrue(organizationIndex < patientIndex); + assertTrue(patientIndex < observationIndex); + } + + @Test + public void testTransactionSorterFailsOnCyclicReference() { + Bundle b = new Bundle(); + Bundle.BundleEntryComponent bundleEntryComponent = b.addEntry(); + final Observation obs1 = new Observation(); + obs1.setStatus(Observation.ObservationStatus.FINAL); + obs1.setSubject(new Reference("Patient/P1")); + obs1.setValue(new Quantity(4)); + obs1.setId("Observation/O1"); + obs1.setHasMember(Collections.singletonList(new Reference("Observation/O2"))); + bundleEntryComponent.setResource(obs1); + bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.POST).setUrl("Observation"); + + bundleEntryComponent = b.addEntry(); + final Observation obs2 = new Observation(); + obs2.setStatus(Observation.ObservationStatus.FINAL); + obs2.setValue(new Quantity(4)); + obs2.setId("Observation/O2"); + obs2.setHasMember(Collections.singletonList(new Reference("Observation/O1"))); + bundleEntryComponent.setResource(obs2); + bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.POST).setUrl("Observation"); + List postBundleEntries = BundleUtil.topologicalSort(myFhirCtx, b, RequestTypeEnum.POST); + + //Null value indicates that we hit a cycle, and could not process the deletions in order + assertThat(postBundleEntries, is(nullValue())); + } + + @Test + public void testTransactionSorterReturnsDeletesInCorrectProcessingOrder() { + Bundle b = new Bundle(); + Bundle.BundleEntryComponent bundleEntryComponent = b.addEntry(); + final Observation obs1 = new Observation(); + obs1.setStatus(Observation.ObservationStatus.FINAL); + obs1.setSubject(new Reference("Patient/P1")); + obs1.setValue(new Quantity(4)); + obs1.setId("Observation/O1"); + bundleEntryComponent.setResource(obs1); + bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Observation"); + + bundleEntryComponent = b.addEntry(); + final Observation obs2 = new Observation(); + obs2.setStatus(Observation.ObservationStatus.FINAL); + obs2.setValue(new Quantity(4)); + obs2.setId("Observation/O2"); + bundleEntryComponent.setResource(obs2); + bundleEntryComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Observation"); + + Bundle.BundleEntryComponent patientComponent = b.addEntry(); + Patient pat1 = new Patient(); + pat1.setId("Patient/P1"); + pat1.setManagingOrganization(new Reference("Organization/Org1")); + patientComponent.setResource(pat1); + patientComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Patient"); + + Bundle.BundleEntryComponent organizationComponent = b.addEntry(); + Organization org1 = new Organization(); + org1.setId("Organization/Org1"); + organizationComponent.setResource(org1); + organizationComponent.getRequest().setMethod(Bundle.HTTPVerb.DELETE).setUrl("Organization"); + + List postBundleEntries = BundleUtil.topologicalSort(myFhirCtx, b, RequestTypeEnum.DELETE); + + assertThat(postBundleEntries, hasSize(4)); + + int observationIndex = getIndexOfEntryWithId("Observation/O1", postBundleEntries); + int patientIndex = getIndexOfEntryWithId("Patient/P1", postBundleEntries); + int organizationIndex = getIndexOfEntryWithId("Organization/Org1", postBundleEntries); + + assertTrue(patientIndex < organizationIndex); + assertTrue(observationIndex < patientIndex); + } + + private int getIndexOfEntryWithId(String theResourceId, List theBundleEntryParts) { + for (int i = 0; i < theBundleEntryParts.size(); i++) { + String id = theBundleEntryParts.get(i).getResource().getIdElement().toUnqualifiedVersionless().toString(); + if (id.equals(theResourceId)) { + return i; + } + } + fail("Didn't find resource with ID " + theResourceId); + return -1; } @Test