4693 enable partitioning in bulk import (#4694)

* created failing tests

* implemented feature, added more tests

* added documentation and changelog

* fixed duplicate error code

* code review changes

---------

Co-authored-by: Steven Li <steven@smilecdr.com>
This commit is contained in:
StevenXLi 2023-03-31 09:43:40 -04:00 committed by GitHub
parent ccd1e94a47
commit 3fb9a16975
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 273 additions and 49 deletions

View File

@ -0,0 +1,4 @@
---
type: change
issue: 4693
title: "Bulk import operations have been enhanced to be fully partition aware."

View File

@ -4,3 +4,5 @@ These changes will be applied automatically on first startup.
To avoid this delay on first startup, run the migration manually. To avoid this delay on first startup, run the migration manually.
Bulk export behaviour is changing in this release such that Binary resources created as part of the response will now be created in the partition that the bulk export was requested rather than in the DEFAULT partition as was being done previously. Bulk export behaviour is changing in this release such that Binary resources created as part of the response will now be created in the partition that the bulk export was requested rather than in the DEFAULT partition as was being done previously.
Bulk import behaviour is changing in this release such that data imported as part of the request will now create resources in the partition that the bulk import was requested rather than in the DEFAULT partition as was being done previously.

View File

@ -2,15 +2,21 @@ package ca.uhn.fhir.jpa.bulk.imprt2;
import ca.uhn.fhir.batch2.api.JobExecutionFailedException; import ca.uhn.fhir.batch2.api.JobExecutionFailedException;
import ca.uhn.fhir.batch2.jobs.imprt.ConsumeFilesStep; import ca.uhn.fhir.batch2.jobs.imprt.ConsumeFilesStep;
import ca.uhn.fhir.jpa.test.BaseJpaR4Test; import ca.uhn.fhir.interceptor.model.RequestPartitionId;
import ca.uhn.fhir.jpa.dao.r4.BasePartitioningR4Test;
import org.hl7.fhir.instance.model.api.IBaseResource; import org.hl7.fhir.instance.model.api.IBaseResource;
import org.hl7.fhir.r4.model.IdType; import org.hl7.fhir.r4.model.IdType;
import org.hl7.fhir.r4.model.Patient; import org.hl7.fhir.r4.model.Patient;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.MethodOrderer; import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder; import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import javax.servlet.ServletException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -23,11 +29,24 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@TestMethodOrder(MethodOrderer.MethodName.class) @TestMethodOrder(MethodOrderer.MethodName.class)
public class ConsumeFilesStepR4Test extends BaseJpaR4Test { public class ConsumeFilesStepR4Test extends BasePartitioningR4Test {
@Autowired @Autowired
private ConsumeFilesStep mySvc; private ConsumeFilesStep mySvc;
private final RequestPartitionId myRequestPartitionId = RequestPartitionId.fromPartitionIdAndName(1, "PART-1");
@BeforeEach
@Override
public void before() throws ServletException {
super.before();
myPartitionSettings.setPartitioningEnabled(false);
}
@AfterEach
@Override
public void after() {
super.after();
}
@Test @Test
public void testAlreadyExisting_NoChanges() { public void testAlreadyExisting_NoChanges() {
// Setup // Setup
@ -59,11 +78,11 @@ public class ConsumeFilesStepR4Test extends BaseJpaR4Test {
myMemoryCacheService.invalidateAllCaches(); myMemoryCacheService.invalidateAllCaches();
myCaptureQueriesListener.clear(); myCaptureQueriesListener.clear();
mySvc.storeResources(resources); mySvc.storeResources(resources, null);
// Validate // Validate
assertEquals(4, myCaptureQueriesListener.logSelectQueries().size()); assertEquals(7, myCaptureQueriesListener.logSelectQueries().size());
assertEquals(0, myCaptureQueriesListener.countInsertQueries()); assertEquals(0, myCaptureQueriesListener.countInsertQueries());
assertEquals(0, myCaptureQueriesListener.countUpdateQueries()); assertEquals(0, myCaptureQueriesListener.countUpdateQueries());
assertEquals(0, myCaptureQueriesListener.countDeleteQueries()); assertEquals(0, myCaptureQueriesListener.countDeleteQueries());
@ -77,23 +96,28 @@ public class ConsumeFilesStepR4Test extends BaseJpaR4Test {
} }
@Test @ParameterizedTest
public void testAlreadyExisting_WithChanges() { @ValueSource(booleans = {false, true})
public void testAlreadyExisting_WithChanges(boolean partitionEnabled) {
// Setup // Setup
if (partitionEnabled) {
myPartitionSettings.setPartitioningEnabled(true);
myPartitionSettings.setIncludePartitionInSearchHashes(true);
addCreatePartition(1);
addCreatePartition(1);
}
Patient patient = new Patient(); Patient patient = new Patient();
patient.setId("A"); patient.setId("A");
patient.setActive(false); patient.setActive(false);
myPatientDao.update(patient); myPatientDao.update(patient, mySrd);
patient = new Patient(); patient = new Patient();
patient.setId("B"); patient.setId("B");
patient.setActive(true); patient.setActive(true);
myPatientDao.update(patient); myPatientDao.update(patient, mySrd);
List<IBaseResource> resources = new ArrayList<>(); List<IBaseResource> resources = new ArrayList<>();
patient = new Patient(); patient = new Patient();
patient.setId("Patient/A"); patient.setId("Patient/A");
patient.setActive(true); patient.setActive(true);
@ -108,20 +132,26 @@ public class ConsumeFilesStepR4Test extends BaseJpaR4Test {
myMemoryCacheService.invalidateAllCaches(); myMemoryCacheService.invalidateAllCaches();
myCaptureQueriesListener.clear(); myCaptureQueriesListener.clear();
mySvc.storeResources(resources); if (partitionEnabled) {
addReadPartition(1);
addReadPartition(1);
mySvc.storeResources(resources, myRequestPartitionId);
} else {
mySvc.storeResources(resources, null);
}
// Validate // Validate
assertEquals(4, myCaptureQueriesListener.logSelectQueries().size()); assertEquals(7, myCaptureQueriesListener.logSelectQueries().size());
assertEquals(2, myCaptureQueriesListener.logInsertQueries()); assertEquals(2, myCaptureQueriesListener.logInsertQueries());
assertEquals(4, myCaptureQueriesListener.logUpdateQueries()); assertEquals(4, myCaptureQueriesListener.logUpdateQueries());
assertEquals(0, myCaptureQueriesListener.countDeleteQueries()); assertEquals(0, myCaptureQueriesListener.countDeleteQueries());
assertEquals(1, myCaptureQueriesListener.countCommits()); assertEquals(1, myCaptureQueriesListener.countCommits());
assertEquals(0, myCaptureQueriesListener.countRollbacks()); assertEquals(0, myCaptureQueriesListener.countRollbacks());
patient = myPatientDao.read(new IdType("Patient/A")); patient = myPatientDao.read(new IdType("Patient/A"), mySrd);
assertTrue(patient.getActive()); assertTrue(patient.getActive());
patient = myPatientDao.read(new IdType("Patient/B")); patient = myPatientDao.read(new IdType("Patient/B"), mySrd);
assertFalse(patient.getActive()); assertFalse(patient.getActive());
} }
@ -146,15 +176,15 @@ public class ConsumeFilesStepR4Test extends BaseJpaR4Test {
// Execute // Execute
myCaptureQueriesListener.clear(); myCaptureQueriesListener.clear();
mySvc.storeResources(resources); mySvc.storeResources(resources, null);
// Validate // Validate
assertEquals(1, myCaptureQueriesListener.logSelectQueries().size()); assertEquals(1, myCaptureQueriesListener.logSelectQueries().size());
assertThat(myCaptureQueriesListener.getSelectQueries().get(0).getSql(true, false), assertThat(myCaptureQueriesListener.getSelectQueries().get(0).getSql(true, false),
either(containsString("forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='B' or forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='A'")) either(containsString("forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='B' and (forcedid0_.PARTITION_ID is null) or forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='A' and (forcedid0_.PARTITION_ID is null)"))
.or(containsString("forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='A' or forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='B'"))); .or(containsString("forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='A' and (forcedid0_.PARTITION_ID is null) or forcedid0_.RESOURCE_TYPE='Patient' and forcedid0_.FORCED_ID='B' and (forcedid0_.PARTITION_ID is null)")));
assertEquals(10, myCaptureQueriesListener.logInsertQueries()); assertEquals(52, myCaptureQueriesListener.logInsertQueries());
assertEquals(0, myCaptureQueriesListener.countUpdateQueries()); assertEquals(0, myCaptureQueriesListener.countUpdateQueries());
assertEquals(0, myCaptureQueriesListener.countDeleteQueries()); assertEquals(0, myCaptureQueriesListener.countDeleteQueries());
assertEquals(1, myCaptureQueriesListener.countCommits()); assertEquals(1, myCaptureQueriesListener.countCommits());
@ -189,7 +219,7 @@ public class ConsumeFilesStepR4Test extends BaseJpaR4Test {
myCaptureQueriesListener.clear(); myCaptureQueriesListener.clear();
try { try {
mySvc.storeResources(resources); mySvc.storeResources(resources, null);
fail(); fail();
} catch (JobExecutionFailedException e) { } catch (JobExecutionFailedException e) {

View File

@ -24,8 +24,10 @@ import ca.uhn.fhir.batch2.model.JobInstance;
import ca.uhn.fhir.batch2.model.JobInstanceStartRequest; import ca.uhn.fhir.batch2.model.JobInstanceStartRequest;
import ca.uhn.fhir.context.FhirContext; import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.i18n.Msg; import ca.uhn.fhir.i18n.Msg;
import ca.uhn.fhir.interceptor.model.RequestPartitionId;
import ca.uhn.fhir.jpa.batch.models.Batch2JobStartResponse; import ca.uhn.fhir.jpa.batch.models.Batch2JobStartResponse;
import ca.uhn.fhir.jpa.model.util.JpaConstants; import ca.uhn.fhir.jpa.model.util.JpaConstants;
import ca.uhn.fhir.jpa.partition.IRequestPartitionHelperSvc;
import ca.uhn.fhir.rest.annotation.Operation; import ca.uhn.fhir.rest.annotation.Operation;
import ca.uhn.fhir.rest.annotation.OperationParam; import ca.uhn.fhir.rest.annotation.OperationParam;
import ca.uhn.fhir.rest.annotation.ResourceParam; import ca.uhn.fhir.rest.annotation.ResourceParam;
@ -77,7 +79,8 @@ public class BulkDataImportProvider {
private IJobCoordinator myJobCoordinator; private IJobCoordinator myJobCoordinator;
@Autowired @Autowired
private FhirContext myFhirCtx; private FhirContext myFhirCtx;
@Autowired
private IRequestPartitionHelperSvc myRequestPartitionHelperService;
private volatile List<String> myResourceTypeOrder; private volatile List<String> myResourceTypeOrder;
/** /**
@ -95,6 +98,11 @@ public class BulkDataImportProvider {
myFhirCtx = theCtx; myFhirCtx = theCtx;
} }
public void setRequestPartitionHelperService(IRequestPartitionHelperSvc theRequestPartitionHelperSvc) {
myRequestPartitionHelperService = theRequestPartitionHelperSvc;
}
/** /**
* $import operation (Import by Manifest) * $import operation (Import by Manifest)
* <p> * <p>
@ -139,6 +147,12 @@ public class BulkDataImportProvider {
} }
} }
RequestPartitionId partitionId = myRequestPartitionHelperService.determineReadPartitionForRequest(theRequestDetails, null);
if (partitionId != null && !partitionId.isAllPartitions()) {
myRequestPartitionHelperService.validateHasPartitionPermissions(theRequestDetails, "Binary", partitionId);
jobParameters.setPartitionId(partitionId);
}
// Extract all the URLs and order them in the order that is least // Extract all the URLs and order them in the order that is least
// likely to result in conflict (e.g. Patients before Observations // likely to result in conflict (e.g. Patients before Observations
// since Observations can reference Patients but not vice versa) // since Observations can reference Patients but not vice versa)
@ -203,13 +217,22 @@ public class BulkDataImportProvider {
) throws IOException { ) throws IOException {
HttpServletResponse response = theRequestDetails.getServletResponse(); HttpServletResponse response = theRequestDetails.getServletResponse();
theRequestDetails.getServer().addHeadersToResponse(response); theRequestDetails.getServer().addHeadersToResponse(response);
JobInstance status = myJobCoordinator.getInstance(theJobId.getValueAsString()); JobInstance instance = myJobCoordinator.getInstance(theJobId.getValueAsString());
BulkImportJobParameters parameters = instance.getParameters(BulkImportJobParameters.class);
if (parameters != null && parameters.getPartitionId() != null) {
// Determine and validate permissions for partition (if needed)
RequestPartitionId partitionId = myRequestPartitionHelperService.determineReadPartitionForRequest(theRequestDetails, null);
myRequestPartitionHelperService.validateHasPartitionPermissions(theRequestDetails, "Binary", partitionId);
if (!partitionId.equals(parameters.getPartitionId())) {
throw new InvalidRequestException(Msg.code(2310) + "Invalid partition in request for Job ID " + theJobId);
}
}
IBaseOperationOutcome oo; IBaseOperationOutcome oo;
switch (status.getStatus()) { switch (instance.getStatus()) {
case QUEUED: { case QUEUED: {
response.setStatus(Constants.STATUS_HTTP_202_ACCEPTED); response.setStatus(Constants.STATUS_HTTP_202_ACCEPTED);
String msg = "Job was created at " + renderTime(status.getCreateTime()) + String msg = "Job was created at " + renderTime(instance.getCreateTime()) +
" and is in " + status.getStatus() + " and is in " + instance.getStatus() +
" state."; " state.";
response.addHeader(Constants.HEADER_X_PROGRESS, msg); response.addHeader(Constants.HEADER_X_PROGRESS, msg);
response.addHeader(Constants.HEADER_RETRY_AFTER, "120"); response.addHeader(Constants.HEADER_RETRY_AFTER, "120");
@ -218,12 +241,12 @@ public class BulkDataImportProvider {
} }
case IN_PROGRESS: { case IN_PROGRESS: {
response.setStatus(Constants.STATUS_HTTP_202_ACCEPTED); response.setStatus(Constants.STATUS_HTTP_202_ACCEPTED);
String msg = "Job was created at " + renderTime(status.getCreateTime()) + String msg = "Job was created at " + renderTime(instance.getCreateTime()) +
", started at " + renderTime(status.getStartTime()) + ", started at " + renderTime(instance.getStartTime()) +
" and is in " + status.getStatus() + " and is in " + instance.getStatus() +
" state. Current completion: " + " state. Current completion: " +
new DecimalFormat("0.0").format(100.0 * status.getProgress()) + new DecimalFormat("0.0").format(100.0 * instance.getProgress()) +
"% and ETA is " + status.getEstimatedTimeRemaining(); "% and ETA is " + instance.getEstimatedTimeRemaining();
response.addHeader(Constants.HEADER_X_PROGRESS, msg); response.addHeader(Constants.HEADER_X_PROGRESS, msg);
response.addHeader(Constants.HEADER_RETRY_AFTER, "120"); response.addHeader(Constants.HEADER_RETRY_AFTER, "120");
streamOperationOutcomeResponse(response, msg, "information"); streamOperationOutcomeResponse(response, msg, "information");
@ -238,8 +261,8 @@ public class BulkDataImportProvider {
case FAILED: case FAILED:
case ERRORED: { case ERRORED: {
response.setStatus(Constants.STATUS_HTTP_500_INTERNAL_ERROR); response.setStatus(Constants.STATUS_HTTP_500_INTERNAL_ERROR);
String msg = "Job is in " + status.getStatus() + " state with " + String msg = "Job is in " + instance.getStatus() + " state with " +
status.getErrorCount() + " error count. Last error: " + status.getErrorMessage(); instance.getErrorCount() + " error count. Last error: " + instance.getErrorMessage();
streamOperationOutcomeResponse(response, msg, "error"); streamOperationOutcomeResponse(response, msg, "error");
break; break;
} }

View File

@ -19,6 +19,7 @@
*/ */
package ca.uhn.fhir.batch2.jobs.imprt; package ca.uhn.fhir.batch2.jobs.imprt;
import ca.uhn.fhir.interceptor.model.RequestPartitionId;
import ca.uhn.fhir.model.api.IModelJson; import ca.uhn.fhir.model.api.IModelJson;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.Validate;
@ -51,6 +52,10 @@ public class BulkImportJobParameters implements IModelJson {
@Nullable @Nullable
private Integer myMaxBatchResourceCount; private Integer myMaxBatchResourceCount;
@JsonProperty(value = "partitionId", required = false)
@Nullable
private RequestPartitionId myPartitionId;
public List<String> getNdJsonUrls() { public List<String> getNdJsonUrls() {
if (myNdJsonUrls == null) { if (myNdJsonUrls == null) {
myNdJsonUrls = new ArrayList<>(); myNdJsonUrls = new ArrayList<>();
@ -82,4 +87,14 @@ public class BulkImportJobParameters implements IModelJson {
getNdJsonUrls().add(theUrl); getNdJsonUrls().add(theUrl);
return this; return this;
} }
@Nullable
public RequestPartitionId getPartitionId() {
return myPartitionId;
}
public BulkImportJobParameters setPartitionId(@Nullable RequestPartitionId thePartitionId) {
myPartitionId = thePartitionId;
return this;
}
} }

View File

@ -33,10 +33,10 @@ import ca.uhn.fhir.jpa.api.dao.IFhirResourceDao;
import ca.uhn.fhir.jpa.api.dao.IFhirSystemDao; import ca.uhn.fhir.jpa.api.dao.IFhirSystemDao;
import ca.uhn.fhir.jpa.api.svc.IIdHelperService; import ca.uhn.fhir.jpa.api.svc.IIdHelperService;
import ca.uhn.fhir.jpa.dao.tx.HapiTransactionService; import ca.uhn.fhir.jpa.dao.tx.HapiTransactionService;
import ca.uhn.fhir.rest.api.server.SystemRequestDetails;
import ca.uhn.fhir.parser.DataFormatException; import ca.uhn.fhir.parser.DataFormatException;
import ca.uhn.fhir.parser.IParser; import ca.uhn.fhir.parser.IParser;
import ca.uhn.fhir.rest.api.server.RequestDetails; import ca.uhn.fhir.rest.api.server.RequestDetails;
import ca.uhn.fhir.rest.api.server.SystemRequestDetails;
import ca.uhn.fhir.rest.api.server.storage.IResourcePersistentId; import ca.uhn.fhir.rest.api.server.storage.IResourcePersistentId;
import ca.uhn.fhir.rest.api.server.storage.TransactionDetails; import ca.uhn.fhir.rest.api.server.storage.TransactionDetails;
import ca.uhn.fhir.rest.server.exceptions.InvalidRequestException; import ca.uhn.fhir.rest.server.exceptions.InvalidRequestException;
@ -96,18 +96,23 @@ public class ConsumeFilesStep implements ILastJobStepWorker<BulkImportJobParamet
ourLog.info("Bulk loading {} resources from source {}", resources.size(), sourceName); ourLog.info("Bulk loading {} resources from source {}", resources.size(), sourceName);
storeResources(resources); storeResources(resources, theStepExecutionDetails.getParameters().getPartitionId());
return new RunOutcome(resources.size()); return new RunOutcome(resources.size());
} }
public void storeResources(List<IBaseResource> resources) { public void storeResources(List<IBaseResource> resources, RequestPartitionId thePartitionId) {
RequestDetails requestDetails = new SystemRequestDetails(); SystemRequestDetails requestDetails = new SystemRequestDetails();
if (thePartitionId == null) {
requestDetails.setRequestPartitionId(RequestPartitionId.defaultPartition());
} else {
requestDetails.setRequestPartitionId(thePartitionId);
}
TransactionDetails transactionDetails = new TransactionDetails(); TransactionDetails transactionDetails = new TransactionDetails();
myHapiTransactionService.execute(requestDetails, transactionDetails, tx -> storeResourcesInsideTransaction(resources, requestDetails, transactionDetails)); myHapiTransactionService.execute(requestDetails, transactionDetails, tx -> storeResourcesInsideTransaction(resources, requestDetails, transactionDetails));
} }
private Void storeResourcesInsideTransaction(List<IBaseResource> theResources, RequestDetails theRequestDetails, TransactionDetails theTransactionDetails) { private Void storeResourcesInsideTransaction(List<IBaseResource> theResources, SystemRequestDetails theRequestDetails, TransactionDetails theTransactionDetails) {
Map<IIdType, IBaseResource> ids = new HashMap<>(); Map<IIdType, IBaseResource> ids = new HashMap<>();
for (IBaseResource next : theResources) { for (IBaseResource next : theResources) {
if (!next.getIdElement().hasIdPart()) { if (!next.getIdElement().hasIdPart()) {
@ -122,7 +127,7 @@ public class ConsumeFilesStep implements ILastJobStepWorker<BulkImportJobParamet
} }
List<IIdType> idsList = new ArrayList<>(ids.keySet()); List<IIdType> idsList = new ArrayList<>(ids.keySet());
List<IResourcePersistentId> resolvedIds = myIdHelperService.resolveResourcePersistentIdsWithCache(RequestPartitionId.allPartitions(), idsList, true); List<IResourcePersistentId> resolvedIds = myIdHelperService.resolveResourcePersistentIdsWithCache(theRequestDetails.getRequestPartitionId(), idsList, true);
for (IResourcePersistentId next : resolvedIds) { for (IResourcePersistentId next : resolvedIds) {
IIdType resId = next.getAssociatedResourceId(); IIdType resId = next.getAssociatedResourceId();
theTransactionDetails.addResolvedResourceId(resId, next); theTransactionDetails.addResolvedResourceId(resId, next);

View File

@ -5,10 +5,16 @@ import ca.uhn.fhir.batch2.model.JobInstance;
import ca.uhn.fhir.batch2.model.JobInstanceStartRequest; import ca.uhn.fhir.batch2.model.JobInstanceStartRequest;
import ca.uhn.fhir.batch2.model.StatusEnum; import ca.uhn.fhir.batch2.model.StatusEnum;
import ca.uhn.fhir.context.FhirContext; import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.interceptor.model.ReadPartitionIdRequestDetails;
import ca.uhn.fhir.interceptor.model.RequestPartitionId;
import ca.uhn.fhir.jpa.batch.models.Batch2JobStartResponse; import ca.uhn.fhir.jpa.batch.models.Batch2JobStartResponse;
import ca.uhn.fhir.jpa.model.util.JpaConstants; import ca.uhn.fhir.jpa.model.util.JpaConstants;
import ca.uhn.fhir.jpa.partition.IRequestPartitionHelperSvc;
import ca.uhn.fhir.rest.api.Constants; import ca.uhn.fhir.rest.api.Constants;
import ca.uhn.fhir.rest.api.server.RequestDetails;
import ca.uhn.fhir.rest.client.apache.ResourceEntity; import ca.uhn.fhir.rest.client.apache.ResourceEntity;
import ca.uhn.fhir.rest.server.exceptions.ForbiddenOperationException;
import ca.uhn.fhir.rest.server.tenant.UrlBaseTenantIdentificationStrategy;
import ca.uhn.fhir.test.utilities.HttpClientExtension; import ca.uhn.fhir.test.utilities.HttpClientExtension;
import ca.uhn.fhir.test.utilities.server.RestfulServerExtension; import ca.uhn.fhir.test.utilities.server.RestfulServerExtension;
import com.google.common.base.Charsets; import com.google.common.base.Charsets;
@ -16,6 +22,7 @@ import org.apache.commons.io.IOUtils;
import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpPost;
import org.hl7.fhir.instance.model.api.IBaseResource;
import org.hl7.fhir.r4.model.CodeType; import org.hl7.fhir.r4.model.CodeType;
import org.hl7.fhir.r4.model.InstantType; import org.hl7.fhir.r4.model.InstantType;
import org.hl7.fhir.r4.model.OperationOutcome; import org.hl7.fhir.r4.model.OperationOutcome;
@ -23,6 +30,8 @@ import org.hl7.fhir.r4.model.Parameters;
import org.hl7.fhir.r4.model.StringType; import org.hl7.fhir.r4.model.StringType;
import org.hl7.fhir.r4.model.UriType; import org.hl7.fhir.r4.model.UriType;
import org.hl7.fhir.r4.model.UrlType; import org.hl7.fhir.r4.model.UrlType;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.MethodOrderer; import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -30,10 +39,13 @@ import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Captor; import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -42,11 +54,14 @@ import javax.annotation.Nonnull;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Date; import java.util.Date;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Stream;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -54,7 +69,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@TestMethodOrder(value=MethodOrderer.MethodName.class) @TestMethodOrder(value = MethodOrderer.MethodName.class)
public class BulkDataImportProviderTest { public class BulkDataImportProviderTest {
private static final Logger ourLog = LoggerFactory.getLogger(BulkDataImportProviderTest.class); private static final Logger ourLog = LoggerFactory.getLogger(BulkDataImportProviderTest.class);
private static final String A_JOB_ID = "0000000-A1A1A1A1"; private static final String A_JOB_ID = "0000000-A1A1A1A1";
@ -68,18 +83,35 @@ public class BulkDataImportProviderTest {
private IJobCoordinator myJobCoordinator; private IJobCoordinator myJobCoordinator;
@Captor @Captor
private ArgumentCaptor<JobInstanceStartRequest> myStartRequestCaptor; private ArgumentCaptor<JobInstanceStartRequest> myStartRequestCaptor;
@Spy
private IRequestPartitionHelperSvc myRequestPartitionHelperSvc = new MyRequestPartitionHelperSvc();
private final RequestPartitionId myRequestPartitionId = RequestPartitionId.fromPartitionIdAndName(123, "Partition-A");
private final String myPartitionName = "Partition-A";
@BeforeEach @BeforeEach
public void beforeEach() { public void beforeEach() {
myProvider.setFhirContext(myCtx); myProvider.setFhirContext(myCtx);
myProvider.setJobCoordinator(myJobCoordinator); myProvider.setJobCoordinator(myJobCoordinator);
myProvider.setRequestPartitionHelperService(myRequestPartitionHelperSvc);
}
public void enablePartitioning() {
myRestfulServerExtension.getRestfulServer().setTenantIdentificationStrategy(new UrlBaseTenantIdentificationStrategy());
}
private static Stream<Arguments> provideParameters() {
return Stream.of(
Arguments.of(UrlType.class, false),
Arguments.of(UriType.class, false),
Arguments.of(UrlType.class, true),
Arguments.of(UriType.class, true)
);
} }
@ParameterizedTest @ParameterizedTest
@ValueSource(classes = {UrlType.class, UriType.class}) @MethodSource("provideParameters")
public void testStart_Success(Class<?> type) throws IOException { public void testStartWithPartitioning_Success(Class<?> type, boolean partitionEnabled) throws IOException {
// Setup // Setup
Parameters input = createRequest(type); Parameters input = createRequest(type);
ourLog.debug("Input: {}", myCtx.newJsonParser().setPrettyPrint(true).encodeResourceToString(input)); ourLog.debug("Input: {}", myCtx.newJsonParser().setPrettyPrint(true).encodeResourceToString(input));
@ -89,7 +121,14 @@ public class BulkDataImportProviderTest {
when(myJobCoordinator.startInstance(any())) when(myJobCoordinator.startInstance(any()))
.thenReturn(startResponse); .thenReturn(startResponse);
String url = myRestfulServerExtension.getBaseUrl() + "/" + JpaConstants.OPERATION_IMPORT; String requestUrl;
if (partitionEnabled) {
enablePartitioning();
requestUrl = myRestfulServerExtension.getBaseUrl() + "/" + myPartitionName + "/";
} else {
requestUrl = myRestfulServerExtension.getBaseUrl() + "/";
}
String url = requestUrl + JpaConstants.OPERATION_IMPORT;
HttpPost post = new HttpPost(url); HttpPost post = new HttpPost(url);
post.addHeader(Constants.HEADER_PREFER, Constants.HEADER_PREFER_RESPOND_ASYNC); post.addHeader(Constants.HEADER_PREFER, Constants.HEADER_PREFER_RESPOND_ASYNC);
post.setEntity(new ResourceEntity(myCtx, input)); post.setEntity(new ResourceEntity(myCtx, input));
@ -107,14 +146,14 @@ public class BulkDataImportProviderTest {
OperationOutcome oo = myCtx.newJsonParser().parseResource(OperationOutcome.class, resp); OperationOutcome oo = myCtx.newJsonParser().parseResource(OperationOutcome.class, resp);
assertEquals("Bulk import job has been submitted with ID: " + jobId, oo.getIssue().get(0).getDiagnostics()); assertEquals("Bulk import job has been submitted with ID: " + jobId, oo.getIssue().get(0).getDiagnostics());
assertEquals("Use the following URL to poll for job status: http://localhost:" + myRestfulServerExtension.getPort() + "/$import-poll-status?_jobId=" + jobId, oo.getIssue().get(1).getDiagnostics()); assertEquals("Use the following URL to poll for job status: " + requestUrl + "$import-poll-status?_jobId=" + jobId, oo.getIssue().get(1).getDiagnostics());
} }
verify(myJobCoordinator, times(1)).startInstance(myStartRequestCaptor.capture()); verify(myJobCoordinator, times(1)).startInstance(myStartRequestCaptor.capture());
JobInstanceStartRequest startRequest = myStartRequestCaptor.getValue(); JobInstanceStartRequest startRequest = myStartRequestCaptor.getValue();
ourLog.info("Parameters: {}", startRequest.getParameters()); ourLog.info("Parameters: {}", startRequest.getParameters());
assertEquals("{\"ndJsonUrls\":[\"http://example.com/Patient\",\"http://example.com/Observation\"],\"maxBatchResourceCount\":500}", startRequest.getParameters()); assertTrue(startRequest.getParameters().startsWith("{\"ndJsonUrls\":[\"http://example.com/Patient\",\"http://example.com/Observation\"],\"maxBatchResourceCount\":500"));
} }
@Test @Test
@ -172,7 +211,8 @@ public class BulkDataImportProviderTest {
} }
@Nonnull Parameters createRequest() { @Nonnull
Parameters createRequest() {
return createRequest(UriType.class); return createRequest(UriType.class);
} }
@ -242,8 +282,9 @@ public class BulkDataImportProviderTest {
} }
} }
@Test @ParameterizedTest
public void testPollForStatus_COMPLETE() throws IOException { @ValueSource(booleans = {false, true})
public void testPollForStatus_COMPLETE(boolean partitionEnabled) throws IOException {
JobInstance jobInfo = new JobInstance() JobInstance jobInfo = new JobInstance()
.setStatus(StatusEnum.COMPLETED) .setStatus(StatusEnum.COMPLETED)
.setCreateTime(parseDate("2022-01-01T12:00:00-04:00")) .setCreateTime(parseDate("2022-01-01T12:00:00-04:00"))
@ -251,7 +292,16 @@ public class BulkDataImportProviderTest {
.setEndTime(parseDate("2022-01-01T12:10:00-04:00")); .setEndTime(parseDate("2022-01-01T12:10:00-04:00"));
when(myJobCoordinator.getInstance(eq(A_JOB_ID))).thenReturn(jobInfo); when(myJobCoordinator.getInstance(eq(A_JOB_ID))).thenReturn(jobInfo);
String url = "http://localhost:" + myRestfulServerExtension.getPort() + "/" + JpaConstants.OPERATION_IMPORT_POLL_STATUS + "?" + String requestUrl;
if (partitionEnabled) {
enablePartitioning();
requestUrl = myRestfulServerExtension.getBaseUrl() + "/" + myPartitionName + "/";
BulkImportJobParameters jobParameters = new BulkImportJobParameters().setPartitionId(myRequestPartitionId);
jobInfo.setParameters(jobParameters);
} else {
requestUrl = myRestfulServerExtension.getBaseUrl() + "/";
}
String url = requestUrl + JpaConstants.OPERATION_IMPORT_POLL_STATUS + "?" +
JpaConstants.PARAM_IMPORT_POLL_STATUS_JOB_ID + "=" + A_JOB_ID; JpaConstants.PARAM_IMPORT_POLL_STATUS_JOB_ID + "=" + A_JOB_ID;
HttpGet get = new HttpGet(url); HttpGet get = new HttpGet(url);
get.addHeader(Constants.HEADER_PREFER, Constants.HEADER_PREFER_RESPOND_ASYNC); get.addHeader(Constants.HEADER_PREFER, Constants.HEADER_PREFER_RESPOND_ASYNC);
@ -290,6 +340,102 @@ public class BulkDataImportProviderTest {
} }
} }
@Test
public void testFailBulkImportRequest_PartitionedWithoutPermissions() throws IOException {
// setup
enablePartitioning();
Parameters input = createRequest();
// test
String url = myRestfulServerExtension.getBaseUrl() + "/Partition-B/" + JpaConstants.OPERATION_IMPORT;
HttpPost post = new HttpPost(url);
post.addHeader(Constants.HEADER_PREFER, Constants.HEADER_PREFER_RESPOND_ASYNC);
post.setEntity(new ResourceEntity(myCtx, input));
ourLog.info("Request: {}", url);
try (CloseableHttpResponse response = myClient.getClient().execute(post)) {
ourLog.info("Response: {}", response);
String resp = IOUtils.toString(response.getEntity().getContent(), StandardCharsets.UTF_8);
ourLog.info(resp);
// Verify
assertEquals(403, response.getStatusLine().getStatusCode());
assertEquals("Forbidden", response.getStatusLine().getReasonPhrase());
}
}
@Test
public void testFailBulkImportPollStatus_PartitionedWithoutPermissions() throws IOException {
// setup
enablePartitioning();
JobInstance jobInfo = new JobInstance()
.setStatus(StatusEnum.COMPLETED)
.setCreateTime(parseDate("2022-01-01T12:00:00-04:00"))
.setStartTime(parseDate("2022-01-01T12:10:00-04:00"))
.setEndTime(parseDate("2022-01-01T12:10:00-04:00"));
when(myJobCoordinator.getInstance(eq(A_JOB_ID))).thenReturn(jobInfo);
BulkImportJobParameters jobParameters = new BulkImportJobParameters().setPartitionId(myRequestPartitionId);
jobInfo.setParameters(jobParameters);
// test
String url = myRestfulServerExtension.getBaseUrl() + "/Partition-B/" + JpaConstants.OPERATION_IMPORT_POLL_STATUS + "?" +
JpaConstants.PARAM_IMPORT_POLL_STATUS_JOB_ID + "=" + A_JOB_ID;
HttpGet get = new HttpGet(url);
get.addHeader(Constants.HEADER_PREFER, Constants.HEADER_PREFER_RESPOND_ASYNC);
try (CloseableHttpResponse response = myClient.execute(get)) {
ourLog.info("Response: {}", response.toString());
// Verify
assertEquals(403, response.getStatusLine().getStatusCode());
assertEquals("Forbidden", response.getStatusLine().getReasonPhrase());
}
}
private class MyRequestPartitionHelperSvc implements IRequestPartitionHelperSvc {
@Nonnull
@Override
public RequestPartitionId determineReadPartitionForRequest(@Nullable RequestDetails theRequest, ReadPartitionIdRequestDetails theDetails) {
assert theRequest != null;
if (myPartitionName.equals(theRequest.getTenantId())) {
return myRequestPartitionId;
} else {
return RequestPartitionId.fromPartitionName(theRequest.getTenantId());
}
}
public void validateHasPartitionPermissions(RequestDetails theRequest, String theResourceType, RequestPartitionId theRequestPartitionId) {
if (!myPartitionName.equals(theRequest.getTenantId()) && theRequest.getTenantId() != null) {
throw new ForbiddenOperationException("User does not have access to resources on the requested partition");
}
}
@Override
public RequestPartitionId determineGenericPartitionForRequest(RequestDetails theRequestDetails) {
return null;
}
@NotNull
@Override
public RequestPartitionId determineCreatePartitionForRequest(@Nullable RequestDetails theRequest, @NotNull IBaseResource theResource, @NotNull String theResourceType) {
return null;
}
@NotNull
@Override
public Set<Integer> toReadPartitions(@NotNull RequestPartitionId theRequestPartitionId) {
return null;
}
@Override
public boolean isResourcePartitionable(String theResourceType) {
return false;
}
}
private Date parseDate(String theString) { private Date parseDate(String theString) {
return new InstantType(theString).getValue(); return new InstantType(theString).getValue();
} }

View File

@ -300,7 +300,6 @@ public class JobInstance extends JobInstanceStartRequest implements IModelJson,
return this; return this;
} }
public void setJobDefinition(JobDefinition<?> theJobDefinition) { public void setJobDefinition(JobDefinition<?> theJobDefinition) {
setJobDefinitionId(theJobDefinition.getJobDefinitionId()); setJobDefinitionId(theJobDefinition.getJobDefinitionId());
setJobDefinitionVersion(theJobDefinition.getJobDefinitionVersion()); setJobDefinitionVersion(theJobDefinition.getJobDefinitionVersion());