Websocket improvements (#1272)

* subscription loader now only pulls active subscriptions
only allow websocket connections to subscriptions of type websocket

* Added a "flag for deletion" to ActiveSubscription in the SubscriptionRegistry to handle the race condition of a scheduled sync overlapping with a subscription creation.  We could have used a package-scoped semaphore or a pre-remove FHIR read, but this seemed like the safest, simplest and most performant way to handle it.

* ActiveSubscriptionCacheTest

* WebsocketConnectionValidatorTest

* fix compile error in jpa example
This commit is contained in:
Ken Stevens 2019-04-12 08:36:49 -04:00 committed by GitHub
parent ca8b6acdf9
commit bb98ded1fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 385 additions and 72 deletions

View File

@ -183,8 +183,8 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.6</source>
<target>1.6</target>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>

View File

@ -20,7 +20,7 @@ package ca.uhn.fhir.jpa.config;
* #L%
*/
import ca.uhn.fhir.jpa.subscription.module.subscriber.SubscriptionWebsocketHandler;
import ca.uhn.fhir.jpa.subscription.module.subscriber.websocket.SubscriptionWebsocketHandler;
import org.springframework.beans.factory.annotation.Autowire;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

View File

@ -4,7 +4,6 @@ import ca.uhn.fhir.jpa.util.CircularQueueCaptureQueriesListener;
import ca.uhn.fhir.rest.server.interceptor.RequestValidatingInterceptor;
import ca.uhn.fhir.validation.ResultSeverityEnum;
import net.ttddyy.dsproxy.listener.SingleQueryCountHolder;
import net.ttddyy.dsproxy.listener.logging.SLF4JLogLevel;
import net.ttddyy.dsproxy.support.ProxyDataSourceBuilder;
import org.apache.commons.dbcp2.BasicDataSource;
import org.springframework.context.annotation.Bean;
@ -103,7 +102,7 @@ public class TestR4Config extends BaseJavaConfigR4 {
DataSource dataSource = ProxyDataSourceBuilder
.create(retVal)
.logQueryBySlf4j(SLF4JLogLevel.INFO, "SQL")
// .logQueryBySlf4j(SLF4JLogLevel.INFO, "SQL")
// .logSlowQueryBySlf4j(10, TimeUnit.SECONDS)
// .countQuery(new ThreadQueryCountHolder())
.beforeQuery(new BlockLargeNumbersOfParamsListener())

View File

@ -205,7 +205,7 @@ public abstract class BaseResourceProviderR4Test extends BaseJpaR4Test {
fail("Failed to init subscriptions");
}
try {
mySubscriptionLoader.syncSubscriptions();
mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
break;
} catch (ResourceVersionConflictException e) {
Thread.sleep(250);

View File

@ -57,7 +57,7 @@ public class RestHookActivatesPreExistingSubscriptionsR4Test extends BaseResourc
@Before
public void beforeSetSubscriptionActivatingInterceptor() {
SubscriptionActivatingInterceptor.setWaitForSubscriptionActivationSynchronouslyForUnitTest(true);
mySubscriptionLoader.syncSubscriptions();
mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
}
@ -109,7 +109,7 @@ public class RestHookActivatesPreExistingSubscriptionsR4Test extends BaseResourc
createSubscription(criteria2, payload, ourListenerServerBase);
mySubscriptionTestUtil.registerRestHookInterceptor();
mySubscriptionLoader.syncSubscriptions();
mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
sendObservation(code, "SNOMED-CT");

View File

@ -35,6 +35,7 @@ public class RestHookTestR4Test extends BaseSubscriptionsR4Test {
@After
public void cleanupStoppableSubscriptionDeliveringRestHookSubscriber() {
ourLog.info("@After");
myStoppableSubscriptionDeliveringRestHookSubscriber.setCountDownLatch(null);
myStoppableSubscriptionDeliveringRestHookSubscriber.unPause();
}

View File

@ -69,7 +69,7 @@ public class RestHookTestWithInterceptorRegisteredToDaoConfigDstu2Test extends B
ourCreatedObservations.clear();
ourUpdatedObservations.clear();
mySubscriptionLoader.syncSubscriptions();
mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
}
private void waitForQueueToDrain() throws InterruptedException {

View File

@ -0,0 +1,14 @@
package ca.uhn.fhir.jpa.subscription.module;
import ca.uhn.fhir.jpa.model.interceptor.api.HookParams;
import java.util.List;
public interface IPointcutLatch {
void clear();
void setExpectedCount(int count);
List<HookParams> awaitExpected() throws InterruptedException;
}

View File

@ -39,6 +39,7 @@ public class ActiveSubscription {
private CanonicalSubscription mySubscription;
private final SubscribableChannel mySubscribableChannel;
private final Collection<MessageHandler> myDeliveryHandlerSet = new HashSet<>();
private boolean flagForDeletion;
public ActiveSubscription(CanonicalSubscription theSubscription, SubscribableChannel theSubscribableChannel) {
mySubscription = theSubscription;
@ -94,4 +95,12 @@ public class ActiveSubscription {
public void setSubscription(CanonicalSubscription theCanonicalizedSubscription) {
mySubscription = theCanonicalizedSubscription;
}
public boolean isFlagForDeletion() {
return flagForDeletion;
}
public void setFlagForDeletion(boolean theFlagForDeletion) {
flagForDeletion = theFlagForDeletion;
}
}

View File

@ -21,6 +21,8 @@ package ca.uhn.fhir.jpa.subscription.module.cache;
*/
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
@ -29,7 +31,7 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
class ActiveSubscriptionCache {
private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(ActiveSubscriptionCache.class);
private static final Logger ourLog = LoggerFactory.getLogger(ActiveSubscriptionCache.class);
private final Map<String, ActiveSubscription> myCache = new ConcurrentHashMap<>();
@ -63,9 +65,17 @@ class ActiveSubscriptionCache {
public void unregisterAllSubscriptionsNotInCollection(Collection<String> theAllIds) {
for (String next : new ArrayList<>(myCache.keySet())) {
if (!theAllIds.contains(next)) {
ourLog.info("Unregistering Subscription/{}", next);
remove(next);
ActiveSubscription activeSubscription = myCache.get(next);
if (theAllIds.contains(next)) {
// In case we got a false positive from a race condition on a previous sync, unset the flag.
activeSubscription.setFlagForDeletion(false);
} else {
if (activeSubscription.isFlagForDeletion()) {
ourLog.info("Unregistering Subscription/{}", next);
remove(next);
} else {
activeSubscription.setFlagForDeletion(true);
}
}
}
}

View File

@ -74,7 +74,10 @@ public class SubscriptionLoader {
@VisibleForTesting
public int doSyncSubscriptionsForUnitTest() {
return doSyncSubscriptionsWithRetry();
// Two passes for delete flag to take effect
int first = doSyncSubscriptionsWithRetry();
int second = doSyncSubscriptionsWithRetry();
return first + second;
}
synchronized int doSyncSubscriptionsWithRetry() {
@ -87,6 +90,10 @@ public class SubscriptionLoader {
ourLog.debug("Starting sync subscriptions");
SearchParameterMap map = new SearchParameterMap();
map.add(Subscription.SP_STATUS, new TokenOrListParam()
// TODO KHS Ideally we should only be pulling ACTIVE subscriptions here, but this class is overloaded so that
// the @Scheduled task also activates requested subscriptions if their type was enabled after they were requested
// There should be a separate @Scheduled task that looks for requested subscriptions that need to be activated
// independent of the registry loading process.
.addOr(new TokenParam(null, Subscription.SubscriptionStatus.REQUESTED.toCode()))
.addOr(new TokenParam(null, Subscription.SubscriptionStatus.ACTIVE.toCode())));
map.setLoadSynchronousUpTo(SubscriptionConstants.MAX_SUBSCRIPTION_RESULTS);

View File

@ -101,6 +101,7 @@ public class SubscriptionRegistry {
deliveryHandler = Optional.empty();
}
ourLog.info("Registering active subscription {}", theSubscription.getIdElement().toUnqualified().getValue());
ActiveSubscription activeSubscription = new ActiveSubscription(canonicalized, deliveryChannel);
deliveryHandler.ifPresent(activeSubscription::register);
@ -115,11 +116,16 @@ public class SubscriptionRegistry {
public void unregisterSubscription(IIdType theId) {
Validate.notNull(theId);
String subscriptionId = theId.getIdPart();
ourLog.info("Unregistering active subscription {}", theId.toUnqualified().getValue());
myActiveSubscriptionCache.remove(subscriptionId);
}
@PreDestroy
public void unregisterAllSubscriptions() {
// Once to set flag
unregisterAllSubscriptionsNotInCollection(Collections.emptyList());
// Twice to remove
unregisterAllSubscriptionsNotInCollection(Collections.emptyList());
}
@ -143,8 +149,6 @@ public class SubscriptionRegistry {
return true;
}
unregisterSubscription(theSubscription.getIdElement());
} else {
ourLog.info("Registering active subscription {}", theSubscription.getIdElement().toUnqualified().getValue());
}
if (Subscription.SubscriptionStatus.ACTIVE.equals(newSubscription.getStatus())) {
registerSubscription(theSubscription.getIdElement(), theSubscription);

View File

@ -1,4 +1,4 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber;
package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
/*
* #%L
@ -22,10 +22,11 @@ package ca.uhn.fhir.jpa.subscription.module.subscriber;
import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionRegistry;
import ca.uhn.fhir.rest.server.exceptions.ResourceNotFoundException;
import ca.uhn.fhir.jpa.subscription.module.subscriber.ResourceDeliveryMessage;
import org.hl7.fhir.instance.model.api.IIdType;
import org.hl7.fhir.r4.model.IdType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
@ -41,9 +42,9 @@ import javax.annotation.PreDestroy;
import java.io.IOException;
public class SubscriptionWebsocketHandler extends TextWebSocketHandler implements WebSocketHandler {
private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(SubscriptionWebsocketHandler.class);
private static Logger ourLog = LoggerFactory.getLogger(SubscriptionWebsocketHandler.class);
@Autowired
protected SubscriptionRegistry mySubscriptionRegistry;
protected WebsocketConnectionValidator myWebsocketConnectionValidator;
@Autowired
private FhirContext myCtx;
@ -160,34 +161,18 @@ public class SubscriptionWebsocketHandler extends TextWebSocketHandler implement
private IIdType bindSimple(WebSocketSession theSession, String theBindString) {
IdType id = new IdType(theBindString);
if (!id.hasIdPart() || !id.isIdPartValid()) {
WebsocketValidationResponse response = myWebsocketConnectionValidator.validate(id);
if (!response.isValid()) {
try {
String message = "Invalid bind request - No ID included";
ourLog.warn(message);
theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), message));
ourLog.warn(response.getMessage());
theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), response.getMessage()));
} catch (IOException e) {
handleFailure(e);
}
return null;
}
if (id.hasResourceType() == false) {
id = id.withResourceType("Subscription");
}
try {
ActiveSubscription activeSubscription = mySubscriptionRegistry.get(id.getIdPart());
myState = new BoundStaticSubscipriptionState( theSession, activeSubscription);
} catch (ResourceNotFoundException e) {
try {
String message = "Invalid bind request - Unknown subscription: " + id.getValue();
ourLog.warn(message);
theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), message));
} catch (IOException e1) {
handleFailure(e);
}
return null;
}
myState = new BoundStaticSubscipriptionState(theSession, response.getActiveSubscription());
return id;
}

View File

@ -0,0 +1,42 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
import ca.uhn.fhir.jpa.subscription.module.CanonicalSubscriptionChannelType;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionRegistry;
import com.sun.istack.NotNull;
import org.hl7.fhir.r4.model.IdType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
public class WebsocketConnectionValidator {
private static Logger ourLog = LoggerFactory.getLogger(WebsocketConnectionValidator.class);
@Autowired
SubscriptionRegistry mySubscriptionRegistry;
public WebsocketValidationResponse validate(@NotNull IdType id) {
if (!id.hasIdPart() || !id.isIdPartValid()) {
return WebsocketValidationResponse.INVALID_RESPONSE("Invalid bind request - No ID included: " + id.getValue());
}
if (!id.hasResourceType()) {
id = id.withResourceType("Subscription");
}
ActiveSubscription activeSubscription = mySubscriptionRegistry.get(id.getIdPart());
if (activeSubscription == null) {
return WebsocketValidationResponse.INVALID_RESPONSE("Invalid bind request - Unknown subscription: " + id.getValue());
}
if (activeSubscription.getSubscription().getChannelType() != CanonicalSubscriptionChannelType.WEBSOCKET) {
return WebsocketValidationResponse.INVALID_RESPONSE("Subscription " + id.getValue() + " is not a " + CanonicalSubscriptionChannelType.WEBSOCKET + " subscription");
}
return WebsocketValidationResponse.VALID_RESPONSE(activeSubscription);
}
}

View File

@ -0,0 +1,35 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
public class WebsocketValidationResponse {
private final boolean myValid;
private final String myMessage;
private final ActiveSubscription myActiveSubscription;
public static WebsocketValidationResponse INVALID_RESPONSE(String theMessage) {
return new WebsocketValidationResponse(false, theMessage, null);
}
public static WebsocketValidationResponse VALID_RESPONSE(ActiveSubscription theActiveSubscription) {
return new WebsocketValidationResponse(true, null, theActiveSubscription);
}
private WebsocketValidationResponse(boolean theValid, String theMessage, ActiveSubscription theActiveSubscription) {
myValid = theValid;
myMessage = theMessage;
myActiveSubscription = theActiveSubscription;
}
public boolean isValid() {
return myValid;
}
public String getMessage() {
return myMessage;
}
public ActiveSubscription getActiveSubscription() {
return myActiveSubscription;
}
}

View File

@ -1,8 +1,10 @@
package ca.uhn.fhir.jpa.subscription.module;
import ca.uhn.fhir.jpa.model.interceptor.api.HookParams;
import ca.uhn.fhir.jpa.model.interceptor.api.IAnonymousInterceptor;
import ca.uhn.fhir.jpa.model.interceptor.api.Pointcut;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -13,9 +15,11 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class PointcutLatch implements IAnonymousInterceptor {
// TODO KHS copy this version over to hapi-fhir
public class PointcutLatch implements IAnonymousInterceptor, IPointcutLatch {
private static final Logger ourLog = LoggerFactory.getLogger(PointcutLatch.class);
private static final int DEFAULT_TIMEOUT_SECONDS = 10;
private static final FhirObjectPrinter ourFhirObjectToStringMapper = new FhirObjectPrinter();
@ -23,8 +27,9 @@ public class PointcutLatch implements IAnonymousInterceptor {
private final String name;
private CountDownLatch myCountdownLatch;
private AtomicReference<String> myFailure;
private AtomicReference<List<String>> myFailures;
private AtomicReference<List<HookParams>> myCalledWith;
private int myInitialCount;
private Pointcut myPointcut;
public PointcutLatch(Pointcut thePointcut) {
@ -36,6 +41,7 @@ public class PointcutLatch implements IAnonymousInterceptor {
this.name = theName;
}
@Override
public void setExpectedCount(int count) {
if (myCountdownLatch != null) {
throw new PointcutLatchException("setExpectedCount() called before previous awaitExpected() completed.");
@ -45,14 +51,15 @@ public class PointcutLatch implements IAnonymousInterceptor {
}
private void createLatch(int count) {
myFailure = new AtomicReference<>();
myFailures = new AtomicReference<>(new ArrayList<>());
myCalledWith = new AtomicReference<>(new ArrayList<>());
myCountdownLatch = new CountDownLatch(count);
myInitialCount = count;
}
private void setFailure(String failure) {
if (myFailure != null) {
myFailure.set(failure);
private void addFailure(String failure) {
if (myFailures != null) {
myFailures.get().add(failure);
} else {
throw new PointcutLatchException("trying to set failure on latch that hasn't been created: " + failure);
}
@ -62,6 +69,7 @@ public class PointcutLatch implements IAnonymousInterceptor {
return name + " " + this.getClass().getSimpleName();
}
@Override
public List<HookParams> awaitExpected() throws InterruptedException {
return awaitExpectedWithTimeout(DEFAULT_TIMEOUT_SECONDS);
}
@ -70,10 +78,19 @@ public class PointcutLatch implements IAnonymousInterceptor {
List<HookParams> retval = myCalledWith.get();
try {
assertNotNull(getName() + " awaitExpected() called before setExpected() called.", myCountdownLatch);
assertTrue(getName() + " timed out waiting " + timeoutSecond + " seconds for latch to be triggered.", myCountdownLatch.await(timeoutSecond, TimeUnit.SECONDS));
if (!myCountdownLatch.await(timeoutSecond, TimeUnit.SECONDS)) {
throw new AssertionError(getName() + " timed out waiting " + timeoutSecond + " seconds for latch to countdown from " + myInitialCount + " to 0. Is " + myCountdownLatch.getCount() + ".");
}
if (myFailure.get() != null) {
String error = getName() + ": " + myFailure.get();
List<String> failures = myFailures.get();
String error = getName();
if (failures != null && failures.size() > 0) {
if (failures.size() > 1) {
error += " ERRORS: \n";
} else {
error += " ERROR: ";
}
error += failures.stream().collect(Collectors.joining("\n"));
error += "\nLatch called with values: " + myCalledWithString();
throw new AssertionError(error);
}
@ -84,10 +101,7 @@ public class PointcutLatch implements IAnonymousInterceptor {
return retval;
}
public void expectNothing() {
clear();
}
@Override
public void clear() {
myCountdownLatch = null;
}
@ -109,9 +123,9 @@ public class PointcutLatch implements IAnonymousInterceptor {
@Override
public void invoke(Pointcut thePointcut, HookParams theArgs) {
if (myCountdownLatch == null) {
throw new PointcutLatchException("invoke() called before setExpectedCount() called.", theArgs);
throw new PointcutLatchException("invoke() called outside of setExpectedCount() .. awaitExpected(). Probably got more invocations than expected or clear() was called before invoke() arrived.", theArgs);
} else if (myCountdownLatch.getCount() <= 0) {
setFailure("invoke() called " + (1 - myCountdownLatch.getCount()) + " more times than expected.");
addFailure("invoke() called when countdown was zero.");
}
if (myCalledWith.get() != null) {
@ -119,6 +133,9 @@ public class PointcutLatch implements IAnonymousInterceptor {
}
ourLog.info("Called {} {} with {}", name, myCountdownLatch, hookParamsToString(theArgs));
if (myCountdownLatch == null) {
throw new PointcutLatchException("invoke() called outside of setExpectedCount() .. awaitExpected(). Probably got more invocations than expected or clear() was called before invoke() arrived.", theArgs);
}
myCountdownLatch.countDown();
}
@ -139,4 +156,27 @@ public class PointcutLatch implements IAnonymousInterceptor {
private static String hookParamsToString(HookParams hookParams) {
return hookParams.values().stream().map(ourFhirObjectToStringMapper).collect(Collectors.joining(", "));
}
@Override
public String toString() {
return new ToStringBuilder(this)
.append("name", name)
.append("myCountdownLatch", myCountdownLatch)
// .append("myFailures", myFailures)
// .append("myCalledWith", myCalledWith)
.append("myInitialCount", myInitialCount)
.toString();
}
public Object getLatchInvocationParameter() {
return getLatchInvocationParameter(myCalledWith.get());
}
public static Object getLatchInvocationParameter(List<HookParams> theHookParams) {
assertNotNull(theHookParams);
assertEquals("Expected Pointcut to be invoked 1 time", 1, theHookParams.size());
HookParams arg = theHookParams.get(0);
assertEquals("Expected pointcut to be invoked with 1 argument", 1, arg.values().size());
return arg.values().iterator().next();
}
}

View File

@ -0,0 +1,91 @@
package ca.uhn.fhir.jpa.subscription.module.cache;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
public class ActiveSubscriptionCacheTest {
@Test
public void twoPhaseDelete() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
activeSubscriptionCache.put(id1, activeSub1);
assertFalse(activeSub1.isFlagForDeletion());
List<String> saveIds = new ArrayList<>();
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertTrue(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id1));
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertNull(activeSubscriptionCache.get(id1));
}
@Test
public void secondPassUnflags() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
List<String> saveIds = new ArrayList<>();
activeSubscriptionCache.put(id1, activeSub1);
assertFalse(activeSub1.isFlagForDeletion());
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertTrue(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id1));
saveIds.add(id1);
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertFalse(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id1));
}
@Test
public void onlyFlaggedDeleted() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
ActiveSubscription activeSub2 = new ActiveSubscription(null, null);
String id2 = "id2";
activeSubscriptionCache.put(id1, activeSub1);
activeSubscriptionCache.put(id2, activeSub2);
activeSub1.setFlagForDeletion(true);
List<String> saveIds = new ArrayList<>();
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertNull(activeSubscriptionCache.get(id1));
assertNotNull(activeSubscriptionCache.get(id2));
assertTrue(activeSub2.isFlagForDeletion());
}
@Test
public void onListSavesAndUnmarksFlag() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
ActiveSubscription activeSub2 = new ActiveSubscription(null, null);
String id2 = "id2";
activeSubscriptionCache.put(id1, activeSub1);
activeSubscriptionCache.put(id2, activeSub2);
activeSub1.setFlagForDeletion(true);
List<String> saveIds = new ArrayList<>();
saveIds.add(id1);
saveIds.add(id2);
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertNotNull(activeSubscriptionCache.get(id1));
assertFalse(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id2));
assertFalse(activeSub2.isFlagForDeletion());
}
}

View File

@ -5,6 +5,7 @@ import ca.uhn.fhir.jpa.model.interceptor.api.HookParams;
import ca.uhn.fhir.jpa.model.interceptor.api.IInterceptorRegistry;
import ca.uhn.fhir.jpa.model.interceptor.api.Pointcut;
import ca.uhn.fhir.jpa.subscription.module.BaseSubscriptionDstu3Test;
import ca.uhn.fhir.jpa.subscription.module.IPointcutLatch;
import ca.uhn.fhir.jpa.subscription.module.PointcutLatch;
import ca.uhn.fhir.jpa.subscription.module.ResourceModifiedMessage;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionChannelFactory;
@ -150,7 +151,7 @@ public abstract class BaseBlockingQueueSubscribableChannelDstu3Test extends Base
ourListenerServer.stop();
}
public static class ObservationListener implements IResourceProvider {
public static class ObservationListener implements IResourceProvider, IPointcutLatch {
private PointcutLatch updateLatch = new PointcutLatch("Observation Update");
@ -176,18 +177,17 @@ public abstract class BaseBlockingQueueSubscribableChannelDstu3Test extends Base
return new MethodOutcome(new IdType("Observation/1"), false);
}
public void setExpectedCount(int count) throws InterruptedException {
@Override
public void setExpectedCount(int count) {
updateLatch.setExpectedCount(count);
}
public void awaitExpected() throws InterruptedException {
updateLatch.awaitExpected();
}
public void expectNothing() {
updateLatch.expectNothing();
@Override
public List<HookParams> awaitExpected() throws InterruptedException {
return updateLatch.awaitExpected();
}
@Override
public void clear() { updateLatch.clear();}
}
}

View File

@ -75,7 +75,7 @@ public class SubscriptionCheckingSubscriberTest extends BaseBlockingQueueSubscri
ourObservationListener.setExpectedCount(0);
sendObservation(code, "SNOMED-CT");
ourObservationListener.expectNothing();
ourObservationListener.clear();
assertEquals(0, ourContentTypes.size());
}

View File

@ -71,7 +71,7 @@ public class SubscriptionMatchingSubscriberTest extends BaseBlockingQueueSubscri
ourObservationListener.setExpectedCount(0);
sendObservation(code, "SNOMED-CT");
ourObservationListener.expectNothing();
ourObservationListener.clear();
assertEquals(0, ourContentTypes.size());
}

View File

@ -0,0 +1,76 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
import ca.uhn.fhir.jpa.subscription.module.CanonicalSubscription;
import ca.uhn.fhir.jpa.subscription.module.CanonicalSubscriptionChannelType;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionRegistry;
import org.hl7.fhir.r4.model.IdType;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringRunner;
import static org.junit.Assert.*;
import static org.mockito.Mockito.when;
@RunWith(SpringRunner.class)
@ContextConfiguration
public class WebsocketConnectionValidatorTest {
public static String RESTHOOK_SUBSCRIPTION_ID = "1";
public static String WEBSOCKET_SUBSCRIPTION_ID = "2";
public static String NON_EXISTENT_SUBSCRIPTION_ID = "3";
@Configuration
@ComponentScan("ca.uhn.fhir.jpa.subscription.module.subscriber.websocket")
public static class SpringConfig {
}
@MockBean
SubscriptionRegistry mySubscriptionRegistry;
@Autowired
WebsocketConnectionValidator myWebsocketConnectionValidator;
@Before
public void before() {
CanonicalSubscription resthookSubscription = new CanonicalSubscription();
resthookSubscription.setChannelType(CanonicalSubscriptionChannelType.RESTHOOK);
ActiveSubscription resthookActiveSubscription = new ActiveSubscription(resthookSubscription, null);
when(mySubscriptionRegistry.get(RESTHOOK_SUBSCRIPTION_ID)).thenReturn(resthookActiveSubscription);
CanonicalSubscription websocketSubscription = new CanonicalSubscription();
websocketSubscription.setChannelType(CanonicalSubscriptionChannelType.WEBSOCKET);
ActiveSubscription websocketActiveSubscription = new ActiveSubscription(websocketSubscription, null);
when(mySubscriptionRegistry.get(WEBSOCKET_SUBSCRIPTION_ID)).thenReturn(websocketActiveSubscription);
}
@Test
public void validateRequest() {
IdType idType;
WebsocketValidationResponse response;
idType = new IdType();
response = myWebsocketConnectionValidator.validate(idType);
assertFalse(response.isValid());
assertEquals("Invalid bind request - No ID included: null", response.getMessage());
idType = new IdType(NON_EXISTENT_SUBSCRIPTION_ID);
response = myWebsocketConnectionValidator.validate(idType);
assertFalse(response.isValid());
assertEquals("Invalid bind request - Unknown subscription: Subscription/" + NON_EXISTENT_SUBSCRIPTION_ID, response.getMessage());
idType = new IdType(RESTHOOK_SUBSCRIPTION_ID);
response = myWebsocketConnectionValidator.validate(idType);
assertFalse(response.isValid());
assertEquals("Subscription Subscription/" + RESTHOOK_SUBSCRIPTION_ID + " is not a WEBSOCKET subscription", response.getMessage());
idType = new IdType(WEBSOCKET_SUBSCRIPTION_ID);
response = myWebsocketConnectionValidator.validate(idType);
assertTrue(response.isValid());
}
}