Websocket improvements (#1272)

* subscription loader now only pulls active subscriptions
only allow websocket connections to subscriptions of type websocket

* Added a "flag for deletion" to ActiveSubscription in the SubscriptionRegistry to handle the race condition of a scheduled sync overlapping with a subscription creation.  We could have used a package-scoped semaphore or a pre-remove FHIR read, but this seemed like the safest, simplest and most performant way to handle it.

* ActiveSubscriptionCacheTest

* WebsocketConnectionValidatorTest

* fix compile error in jpa example
This commit is contained in:
Ken Stevens 2019-04-12 08:36:49 -04:00 committed by GitHub
parent ca8b6acdf9
commit bb98ded1fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 385 additions and 72 deletions

View File

@ -183,8 +183,8 @@
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId> <artifactId>maven-compiler-plugin</artifactId>
<configuration> <configuration>
<source>1.6</source> <source>1.8</source>
<target>1.6</target> <target>1.8</target>
</configuration> </configuration>
</plugin> </plugin>

View File

@ -20,7 +20,7 @@ package ca.uhn.fhir.jpa.config;
* #L% * #L%
*/ */
import ca.uhn.fhir.jpa.subscription.module.subscriber.SubscriptionWebsocketHandler; import ca.uhn.fhir.jpa.subscription.module.subscriber.websocket.SubscriptionWebsocketHandler;
import org.springframework.beans.factory.annotation.Autowire; import org.springframework.beans.factory.annotation.Autowire;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;

View File

@ -4,7 +4,6 @@ import ca.uhn.fhir.jpa.util.CircularQueueCaptureQueriesListener;
import ca.uhn.fhir.rest.server.interceptor.RequestValidatingInterceptor; import ca.uhn.fhir.rest.server.interceptor.RequestValidatingInterceptor;
import ca.uhn.fhir.validation.ResultSeverityEnum; import ca.uhn.fhir.validation.ResultSeverityEnum;
import net.ttddyy.dsproxy.listener.SingleQueryCountHolder; import net.ttddyy.dsproxy.listener.SingleQueryCountHolder;
import net.ttddyy.dsproxy.listener.logging.SLF4JLogLevel;
import net.ttddyy.dsproxy.support.ProxyDataSourceBuilder; import net.ttddyy.dsproxy.support.ProxyDataSourceBuilder;
import org.apache.commons.dbcp2.BasicDataSource; import org.apache.commons.dbcp2.BasicDataSource;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
@ -103,7 +102,7 @@ public class TestR4Config extends BaseJavaConfigR4 {
DataSource dataSource = ProxyDataSourceBuilder DataSource dataSource = ProxyDataSourceBuilder
.create(retVal) .create(retVal)
.logQueryBySlf4j(SLF4JLogLevel.INFO, "SQL") // .logQueryBySlf4j(SLF4JLogLevel.INFO, "SQL")
// .logSlowQueryBySlf4j(10, TimeUnit.SECONDS) // .logSlowQueryBySlf4j(10, TimeUnit.SECONDS)
// .countQuery(new ThreadQueryCountHolder()) // .countQuery(new ThreadQueryCountHolder())
.beforeQuery(new BlockLargeNumbersOfParamsListener()) .beforeQuery(new BlockLargeNumbersOfParamsListener())

View File

@ -205,7 +205,7 @@ public abstract class BaseResourceProviderR4Test extends BaseJpaR4Test {
fail("Failed to init subscriptions"); fail("Failed to init subscriptions");
} }
try { try {
mySubscriptionLoader.syncSubscriptions(); mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
break; break;
} catch (ResourceVersionConflictException e) { } catch (ResourceVersionConflictException e) {
Thread.sleep(250); Thread.sleep(250);

View File

@ -57,7 +57,7 @@ public class RestHookActivatesPreExistingSubscriptionsR4Test extends BaseResourc
@Before @Before
public void beforeSetSubscriptionActivatingInterceptor() { public void beforeSetSubscriptionActivatingInterceptor() {
SubscriptionActivatingInterceptor.setWaitForSubscriptionActivationSynchronouslyForUnitTest(true); SubscriptionActivatingInterceptor.setWaitForSubscriptionActivationSynchronouslyForUnitTest(true);
mySubscriptionLoader.syncSubscriptions(); mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
} }
@ -109,7 +109,7 @@ public class RestHookActivatesPreExistingSubscriptionsR4Test extends BaseResourc
createSubscription(criteria2, payload, ourListenerServerBase); createSubscription(criteria2, payload, ourListenerServerBase);
mySubscriptionTestUtil.registerRestHookInterceptor(); mySubscriptionTestUtil.registerRestHookInterceptor();
mySubscriptionLoader.syncSubscriptions(); mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
sendObservation(code, "SNOMED-CT"); sendObservation(code, "SNOMED-CT");

View File

@ -35,6 +35,7 @@ public class RestHookTestR4Test extends BaseSubscriptionsR4Test {
@After @After
public void cleanupStoppableSubscriptionDeliveringRestHookSubscriber() { public void cleanupStoppableSubscriptionDeliveringRestHookSubscriber() {
ourLog.info("@After");
myStoppableSubscriptionDeliveringRestHookSubscriber.setCountDownLatch(null); myStoppableSubscriptionDeliveringRestHookSubscriber.setCountDownLatch(null);
myStoppableSubscriptionDeliveringRestHookSubscriber.unPause(); myStoppableSubscriptionDeliveringRestHookSubscriber.unPause();
} }

View File

@ -69,7 +69,7 @@ public class RestHookTestWithInterceptorRegisteredToDaoConfigDstu2Test extends B
ourCreatedObservations.clear(); ourCreatedObservations.clear();
ourUpdatedObservations.clear(); ourUpdatedObservations.clear();
mySubscriptionLoader.syncSubscriptions(); mySubscriptionLoader.doSyncSubscriptionsForUnitTest();
} }
private void waitForQueueToDrain() throws InterruptedException { private void waitForQueueToDrain() throws InterruptedException {

View File

@ -0,0 +1,14 @@
package ca.uhn.fhir.jpa.subscription.module;
import ca.uhn.fhir.jpa.model.interceptor.api.HookParams;
import java.util.List;
public interface IPointcutLatch {
void clear();
void setExpectedCount(int count);
List<HookParams> awaitExpected() throws InterruptedException;
}

View File

@ -39,6 +39,7 @@ public class ActiveSubscription {
private CanonicalSubscription mySubscription; private CanonicalSubscription mySubscription;
private final SubscribableChannel mySubscribableChannel; private final SubscribableChannel mySubscribableChannel;
private final Collection<MessageHandler> myDeliveryHandlerSet = new HashSet<>(); private final Collection<MessageHandler> myDeliveryHandlerSet = new HashSet<>();
private boolean flagForDeletion;
public ActiveSubscription(CanonicalSubscription theSubscription, SubscribableChannel theSubscribableChannel) { public ActiveSubscription(CanonicalSubscription theSubscription, SubscribableChannel theSubscribableChannel) {
mySubscription = theSubscription; mySubscription = theSubscription;
@ -94,4 +95,12 @@ public class ActiveSubscription {
public void setSubscription(CanonicalSubscription theCanonicalizedSubscription) { public void setSubscription(CanonicalSubscription theCanonicalizedSubscription) {
mySubscription = theCanonicalizedSubscription; mySubscription = theCanonicalizedSubscription;
} }
public boolean isFlagForDeletion() {
return flagForDeletion;
}
public void setFlagForDeletion(boolean theFlagForDeletion) {
flagForDeletion = theFlagForDeletion;
}
} }

View File

@ -9,9 +9,9 @@ package ca.uhn.fhir.jpa.subscription.module.cache;
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -21,6 +21,8 @@ package ca.uhn.fhir.jpa.subscription.module.cache;
*/ */
import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
@ -29,7 +31,7 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
class ActiveSubscriptionCache { class ActiveSubscriptionCache {
private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(ActiveSubscriptionCache.class); private static final Logger ourLog = LoggerFactory.getLogger(ActiveSubscriptionCache.class);
private final Map<String, ActiveSubscription> myCache = new ConcurrentHashMap<>(); private final Map<String, ActiveSubscription> myCache = new ConcurrentHashMap<>();
@ -63,9 +65,17 @@ class ActiveSubscriptionCache {
public void unregisterAllSubscriptionsNotInCollection(Collection<String> theAllIds) { public void unregisterAllSubscriptionsNotInCollection(Collection<String> theAllIds) {
for (String next : new ArrayList<>(myCache.keySet())) { for (String next : new ArrayList<>(myCache.keySet())) {
if (!theAllIds.contains(next)) { ActiveSubscription activeSubscription = myCache.get(next);
ourLog.info("Unregistering Subscription/{}", next); if (theAllIds.contains(next)) {
remove(next); // In case we got a false positive from a race condition on a previous sync, unset the flag.
activeSubscription.setFlagForDeletion(false);
} else {
if (activeSubscription.isFlagForDeletion()) {
ourLog.info("Unregistering Subscription/{}", next);
remove(next);
} else {
activeSubscription.setFlagForDeletion(true);
}
} }
} }
} }

View File

@ -74,7 +74,10 @@ public class SubscriptionLoader {
@VisibleForTesting @VisibleForTesting
public int doSyncSubscriptionsForUnitTest() { public int doSyncSubscriptionsForUnitTest() {
return doSyncSubscriptionsWithRetry(); // Two passes for delete flag to take effect
int first = doSyncSubscriptionsWithRetry();
int second = doSyncSubscriptionsWithRetry();
return first + second;
} }
synchronized int doSyncSubscriptionsWithRetry() { synchronized int doSyncSubscriptionsWithRetry() {
@ -87,6 +90,10 @@ public class SubscriptionLoader {
ourLog.debug("Starting sync subscriptions"); ourLog.debug("Starting sync subscriptions");
SearchParameterMap map = new SearchParameterMap(); SearchParameterMap map = new SearchParameterMap();
map.add(Subscription.SP_STATUS, new TokenOrListParam() map.add(Subscription.SP_STATUS, new TokenOrListParam()
// TODO KHS Ideally we should only be pulling ACTIVE subscriptions here, but this class is overloaded so that
// the @Scheduled task also activates requested subscriptions if their type was enabled after they were requested
// There should be a separate @Scheduled task that looks for requested subscriptions that need to be activated
// independent of the registry loading process.
.addOr(new TokenParam(null, Subscription.SubscriptionStatus.REQUESTED.toCode())) .addOr(new TokenParam(null, Subscription.SubscriptionStatus.REQUESTED.toCode()))
.addOr(new TokenParam(null, Subscription.SubscriptionStatus.ACTIVE.toCode()))); .addOr(new TokenParam(null, Subscription.SubscriptionStatus.ACTIVE.toCode())));
map.setLoadSynchronousUpTo(SubscriptionConstants.MAX_SUBSCRIPTION_RESULTS); map.setLoadSynchronousUpTo(SubscriptionConstants.MAX_SUBSCRIPTION_RESULTS);

View File

@ -101,6 +101,7 @@ public class SubscriptionRegistry {
deliveryHandler = Optional.empty(); deliveryHandler = Optional.empty();
} }
ourLog.info("Registering active subscription {}", theSubscription.getIdElement().toUnqualified().getValue());
ActiveSubscription activeSubscription = new ActiveSubscription(canonicalized, deliveryChannel); ActiveSubscription activeSubscription = new ActiveSubscription(canonicalized, deliveryChannel);
deliveryHandler.ifPresent(activeSubscription::register); deliveryHandler.ifPresent(activeSubscription::register);
@ -115,11 +116,16 @@ public class SubscriptionRegistry {
public void unregisterSubscription(IIdType theId) { public void unregisterSubscription(IIdType theId) {
Validate.notNull(theId); Validate.notNull(theId);
String subscriptionId = theId.getIdPart(); String subscriptionId = theId.getIdPart();
ourLog.info("Unregistering active subscription {}", theId.toUnqualified().getValue());
myActiveSubscriptionCache.remove(subscriptionId); myActiveSubscriptionCache.remove(subscriptionId);
} }
@PreDestroy @PreDestroy
public void unregisterAllSubscriptions() { public void unregisterAllSubscriptions() {
// Once to set flag
unregisterAllSubscriptionsNotInCollection(Collections.emptyList());
// Twice to remove
unregisterAllSubscriptionsNotInCollection(Collections.emptyList()); unregisterAllSubscriptionsNotInCollection(Collections.emptyList());
} }
@ -143,8 +149,6 @@ public class SubscriptionRegistry {
return true; return true;
} }
unregisterSubscription(theSubscription.getIdElement()); unregisterSubscription(theSubscription.getIdElement());
} else {
ourLog.info("Registering active subscription {}", theSubscription.getIdElement().toUnqualified().getValue());
} }
if (Subscription.SubscriptionStatus.ACTIVE.equals(newSubscription.getStatus())) { if (Subscription.SubscriptionStatus.ACTIVE.equals(newSubscription.getStatus())) {
registerSubscription(theSubscription.getIdElement(), theSubscription); registerSubscription(theSubscription.getIdElement(), theSubscription);

View File

@ -1,4 +1,4 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber; package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
/* /*
* #%L * #%L
@ -9,9 +9,9 @@ package ca.uhn.fhir.jpa.subscription.module.subscriber;
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -22,10 +22,11 @@ package ca.uhn.fhir.jpa.subscription.module.subscriber;
import ca.uhn.fhir.context.FhirContext; import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription; import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionRegistry; import ca.uhn.fhir.jpa.subscription.module.subscriber.ResourceDeliveryMessage;
import ca.uhn.fhir.rest.server.exceptions.ResourceNotFoundException;
import org.hl7.fhir.instance.model.api.IIdType; import org.hl7.fhir.instance.model.api.IIdType;
import org.hl7.fhir.r4.model.IdType; import org.hl7.fhir.r4.model.IdType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessageHandler;
@ -41,9 +42,9 @@ import javax.annotation.PreDestroy;
import java.io.IOException; import java.io.IOException;
public class SubscriptionWebsocketHandler extends TextWebSocketHandler implements WebSocketHandler { public class SubscriptionWebsocketHandler extends TextWebSocketHandler implements WebSocketHandler {
private static final org.slf4j.Logger ourLog = org.slf4j.LoggerFactory.getLogger(SubscriptionWebsocketHandler.class); private static Logger ourLog = LoggerFactory.getLogger(SubscriptionWebsocketHandler.class);
@Autowired @Autowired
protected SubscriptionRegistry mySubscriptionRegistry; protected WebsocketConnectionValidator myWebsocketConnectionValidator;
@Autowired @Autowired
private FhirContext myCtx; private FhirContext myCtx;
@ -160,34 +161,18 @@ public class SubscriptionWebsocketHandler extends TextWebSocketHandler implement
private IIdType bindSimple(WebSocketSession theSession, String theBindString) { private IIdType bindSimple(WebSocketSession theSession, String theBindString) {
IdType id = new IdType(theBindString); IdType id = new IdType(theBindString);
if (!id.hasIdPart() || !id.isIdPartValid()) { WebsocketValidationResponse response = myWebsocketConnectionValidator.validate(id);
if (!response.isValid()) {
try { try {
String message = "Invalid bind request - No ID included"; ourLog.warn(response.getMessage());
ourLog.warn(message); theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), response.getMessage()));
theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), message));
} catch (IOException e) { } catch (IOException e) {
handleFailure(e); handleFailure(e);
} }
return null; return null;
} }
if (id.hasResourceType() == false) { myState = new BoundStaticSubscipriptionState(theSession, response.getActiveSubscription());
id = id.withResourceType("Subscription");
}
try {
ActiveSubscription activeSubscription = mySubscriptionRegistry.get(id.getIdPart());
myState = new BoundStaticSubscipriptionState( theSession, activeSubscription);
} catch (ResourceNotFoundException e) {
try {
String message = "Invalid bind request - Unknown subscription: " + id.getValue();
ourLog.warn(message);
theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), message));
} catch (IOException e1) {
handleFailure(e);
}
return null;
}
return id; return id;
} }

View File

@ -0,0 +1,42 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
import ca.uhn.fhir.jpa.subscription.module.CanonicalSubscriptionChannelType;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionRegistry;
import com.sun.istack.NotNull;
import org.hl7.fhir.r4.model.IdType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
public class WebsocketConnectionValidator {
private static Logger ourLog = LoggerFactory.getLogger(WebsocketConnectionValidator.class);
@Autowired
SubscriptionRegistry mySubscriptionRegistry;
public WebsocketValidationResponse validate(@NotNull IdType id) {
if (!id.hasIdPart() || !id.isIdPartValid()) {
return WebsocketValidationResponse.INVALID_RESPONSE("Invalid bind request - No ID included: " + id.getValue());
}
if (!id.hasResourceType()) {
id = id.withResourceType("Subscription");
}
ActiveSubscription activeSubscription = mySubscriptionRegistry.get(id.getIdPart());
if (activeSubscription == null) {
return WebsocketValidationResponse.INVALID_RESPONSE("Invalid bind request - Unknown subscription: " + id.getValue());
}
if (activeSubscription.getSubscription().getChannelType() != CanonicalSubscriptionChannelType.WEBSOCKET) {
return WebsocketValidationResponse.INVALID_RESPONSE("Subscription " + id.getValue() + " is not a " + CanonicalSubscriptionChannelType.WEBSOCKET + " subscription");
}
return WebsocketValidationResponse.VALID_RESPONSE(activeSubscription);
}
}

View File

@ -0,0 +1,35 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
public class WebsocketValidationResponse {
private final boolean myValid;
private final String myMessage;
private final ActiveSubscription myActiveSubscription;
public static WebsocketValidationResponse INVALID_RESPONSE(String theMessage) {
return new WebsocketValidationResponse(false, theMessage, null);
}
public static WebsocketValidationResponse VALID_RESPONSE(ActiveSubscription theActiveSubscription) {
return new WebsocketValidationResponse(true, null, theActiveSubscription);
}
private WebsocketValidationResponse(boolean theValid, String theMessage, ActiveSubscription theActiveSubscription) {
myValid = theValid;
myMessage = theMessage;
myActiveSubscription = theActiveSubscription;
}
public boolean isValid() {
return myValid;
}
public String getMessage() {
return myMessage;
}
public ActiveSubscription getActiveSubscription() {
return myActiveSubscription;
}
}

View File

@ -1,8 +1,10 @@
package ca.uhn.fhir.jpa.subscription.module; package ca.uhn.fhir.jpa.subscription.module;
import ca.uhn.fhir.jpa.model.interceptor.api.HookParams; import ca.uhn.fhir.jpa.model.interceptor.api.HookParams;
import ca.uhn.fhir.jpa.model.interceptor.api.IAnonymousInterceptor; import ca.uhn.fhir.jpa.model.interceptor.api.IAnonymousInterceptor;
import ca.uhn.fhir.jpa.model.interceptor.api.Pointcut; import ca.uhn.fhir.jpa.model.interceptor.api.Pointcut;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -13,9 +15,11 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.junit.Assert.*; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class PointcutLatch implements IAnonymousInterceptor { // TODO KHS copy this version over to hapi-fhir
public class PointcutLatch implements IAnonymousInterceptor, IPointcutLatch {
private static final Logger ourLog = LoggerFactory.getLogger(PointcutLatch.class); private static final Logger ourLog = LoggerFactory.getLogger(PointcutLatch.class);
private static final int DEFAULT_TIMEOUT_SECONDS = 10; private static final int DEFAULT_TIMEOUT_SECONDS = 10;
private static final FhirObjectPrinter ourFhirObjectToStringMapper = new FhirObjectPrinter(); private static final FhirObjectPrinter ourFhirObjectToStringMapper = new FhirObjectPrinter();
@ -23,8 +27,9 @@ public class PointcutLatch implements IAnonymousInterceptor {
private final String name; private final String name;
private CountDownLatch myCountdownLatch; private CountDownLatch myCountdownLatch;
private AtomicReference<String> myFailure; private AtomicReference<List<String>> myFailures;
private AtomicReference<List<HookParams>> myCalledWith; private AtomicReference<List<HookParams>> myCalledWith;
private int myInitialCount;
private Pointcut myPointcut; private Pointcut myPointcut;
public PointcutLatch(Pointcut thePointcut) { public PointcutLatch(Pointcut thePointcut) {
@ -36,6 +41,7 @@ public class PointcutLatch implements IAnonymousInterceptor {
this.name = theName; this.name = theName;
} }
@Override
public void setExpectedCount(int count) { public void setExpectedCount(int count) {
if (myCountdownLatch != null) { if (myCountdownLatch != null) {
throw new PointcutLatchException("setExpectedCount() called before previous awaitExpected() completed."); throw new PointcutLatchException("setExpectedCount() called before previous awaitExpected() completed.");
@ -45,14 +51,15 @@ public class PointcutLatch implements IAnonymousInterceptor {
} }
private void createLatch(int count) { private void createLatch(int count) {
myFailure = new AtomicReference<>(); myFailures = new AtomicReference<>(new ArrayList<>());
myCalledWith = new AtomicReference<>(new ArrayList<>()); myCalledWith = new AtomicReference<>(new ArrayList<>());
myCountdownLatch = new CountDownLatch(count); myCountdownLatch = new CountDownLatch(count);
myInitialCount = count;
} }
private void setFailure(String failure) { private void addFailure(String failure) {
if (myFailure != null) { if (myFailures != null) {
myFailure.set(failure); myFailures.get().add(failure);
} else { } else {
throw new PointcutLatchException("trying to set failure on latch that hasn't been created: " + failure); throw new PointcutLatchException("trying to set failure on latch that hasn't been created: " + failure);
} }
@ -62,6 +69,7 @@ public class PointcutLatch implements IAnonymousInterceptor {
return name + " " + this.getClass().getSimpleName(); return name + " " + this.getClass().getSimpleName();
} }
@Override
public List<HookParams> awaitExpected() throws InterruptedException { public List<HookParams> awaitExpected() throws InterruptedException {
return awaitExpectedWithTimeout(DEFAULT_TIMEOUT_SECONDS); return awaitExpectedWithTimeout(DEFAULT_TIMEOUT_SECONDS);
} }
@ -70,10 +78,19 @@ public class PointcutLatch implements IAnonymousInterceptor {
List<HookParams> retval = myCalledWith.get(); List<HookParams> retval = myCalledWith.get();
try { try {
assertNotNull(getName() + " awaitExpected() called before setExpected() called.", myCountdownLatch); assertNotNull(getName() + " awaitExpected() called before setExpected() called.", myCountdownLatch);
assertTrue(getName() + " timed out waiting " + timeoutSecond + " seconds for latch to be triggered.", myCountdownLatch.await(timeoutSecond, TimeUnit.SECONDS)); if (!myCountdownLatch.await(timeoutSecond, TimeUnit.SECONDS)) {
throw new AssertionError(getName() + " timed out waiting " + timeoutSecond + " seconds for latch to countdown from " + myInitialCount + " to 0. Is " + myCountdownLatch.getCount() + ".");
}
if (myFailure.get() != null) { List<String> failures = myFailures.get();
String error = getName() + ": " + myFailure.get(); String error = getName();
if (failures != null && failures.size() > 0) {
if (failures.size() > 1) {
error += " ERRORS: \n";
} else {
error += " ERROR: ";
}
error += failures.stream().collect(Collectors.joining("\n"));
error += "\nLatch called with values: " + myCalledWithString(); error += "\nLatch called with values: " + myCalledWithString();
throw new AssertionError(error); throw new AssertionError(error);
} }
@ -84,10 +101,7 @@ public class PointcutLatch implements IAnonymousInterceptor {
return retval; return retval;
} }
public void expectNothing() { @Override
clear();
}
public void clear() { public void clear() {
myCountdownLatch = null; myCountdownLatch = null;
} }
@ -109,9 +123,9 @@ public class PointcutLatch implements IAnonymousInterceptor {
@Override @Override
public void invoke(Pointcut thePointcut, HookParams theArgs) { public void invoke(Pointcut thePointcut, HookParams theArgs) {
if (myCountdownLatch == null) { if (myCountdownLatch == null) {
throw new PointcutLatchException("invoke() called before setExpectedCount() called.", theArgs); throw new PointcutLatchException("invoke() called outside of setExpectedCount() .. awaitExpected(). Probably got more invocations than expected or clear() was called before invoke() arrived.", theArgs);
} else if (myCountdownLatch.getCount() <= 0) { } else if (myCountdownLatch.getCount() <= 0) {
setFailure("invoke() called " + (1 - myCountdownLatch.getCount()) + " more times than expected."); addFailure("invoke() called when countdown was zero.");
} }
if (myCalledWith.get() != null) { if (myCalledWith.get() != null) {
@ -119,6 +133,9 @@ public class PointcutLatch implements IAnonymousInterceptor {
} }
ourLog.info("Called {} {} with {}", name, myCountdownLatch, hookParamsToString(theArgs)); ourLog.info("Called {} {} with {}", name, myCountdownLatch, hookParamsToString(theArgs));
if (myCountdownLatch == null) {
throw new PointcutLatchException("invoke() called outside of setExpectedCount() .. awaitExpected(). Probably got more invocations than expected or clear() was called before invoke() arrived.", theArgs);
}
myCountdownLatch.countDown(); myCountdownLatch.countDown();
} }
@ -139,4 +156,27 @@ public class PointcutLatch implements IAnonymousInterceptor {
private static String hookParamsToString(HookParams hookParams) { private static String hookParamsToString(HookParams hookParams) {
return hookParams.values().stream().map(ourFhirObjectToStringMapper).collect(Collectors.joining(", ")); return hookParams.values().stream().map(ourFhirObjectToStringMapper).collect(Collectors.joining(", "));
} }
@Override
public String toString() {
return new ToStringBuilder(this)
.append("name", name)
.append("myCountdownLatch", myCountdownLatch)
// .append("myFailures", myFailures)
// .append("myCalledWith", myCalledWith)
.append("myInitialCount", myInitialCount)
.toString();
}
public Object getLatchInvocationParameter() {
return getLatchInvocationParameter(myCalledWith.get());
}
public static Object getLatchInvocationParameter(List<HookParams> theHookParams) {
assertNotNull(theHookParams);
assertEquals("Expected Pointcut to be invoked 1 time", 1, theHookParams.size());
HookParams arg = theHookParams.get(0);
assertEquals("Expected pointcut to be invoked with 1 argument", 1, arg.values().size());
return arg.values().iterator().next();
}
} }

View File

@ -0,0 +1,91 @@
package ca.uhn.fhir.jpa.subscription.module.cache;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
public class ActiveSubscriptionCacheTest {
@Test
public void twoPhaseDelete() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
activeSubscriptionCache.put(id1, activeSub1);
assertFalse(activeSub1.isFlagForDeletion());
List<String> saveIds = new ArrayList<>();
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertTrue(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id1));
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertNull(activeSubscriptionCache.get(id1));
}
@Test
public void secondPassUnflags() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
List<String> saveIds = new ArrayList<>();
activeSubscriptionCache.put(id1, activeSub1);
assertFalse(activeSub1.isFlagForDeletion());
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertTrue(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id1));
saveIds.add(id1);
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertFalse(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id1));
}
@Test
public void onlyFlaggedDeleted() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
ActiveSubscription activeSub2 = new ActiveSubscription(null, null);
String id2 = "id2";
activeSubscriptionCache.put(id1, activeSub1);
activeSubscriptionCache.put(id2, activeSub2);
activeSub1.setFlagForDeletion(true);
List<String> saveIds = new ArrayList<>();
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertNull(activeSubscriptionCache.get(id1));
assertNotNull(activeSubscriptionCache.get(id2));
assertTrue(activeSub2.isFlagForDeletion());
}
@Test
public void onListSavesAndUnmarksFlag() {
ActiveSubscriptionCache activeSubscriptionCache = new ActiveSubscriptionCache();
ActiveSubscription activeSub1 = new ActiveSubscription(null, null);
String id1 = "id1";
ActiveSubscription activeSub2 = new ActiveSubscription(null, null);
String id2 = "id2";
activeSubscriptionCache.put(id1, activeSub1);
activeSubscriptionCache.put(id2, activeSub2);
activeSub1.setFlagForDeletion(true);
List<String> saveIds = new ArrayList<>();
saveIds.add(id1);
saveIds.add(id2);
activeSubscriptionCache.unregisterAllSubscriptionsNotInCollection(saveIds);
assertNotNull(activeSubscriptionCache.get(id1));
assertFalse(activeSub1.isFlagForDeletion());
assertNotNull(activeSubscriptionCache.get(id2));
assertFalse(activeSub2.isFlagForDeletion());
}
}

View File

@ -5,6 +5,7 @@ import ca.uhn.fhir.jpa.model.interceptor.api.HookParams;
import ca.uhn.fhir.jpa.model.interceptor.api.IInterceptorRegistry; import ca.uhn.fhir.jpa.model.interceptor.api.IInterceptorRegistry;
import ca.uhn.fhir.jpa.model.interceptor.api.Pointcut; import ca.uhn.fhir.jpa.model.interceptor.api.Pointcut;
import ca.uhn.fhir.jpa.subscription.module.BaseSubscriptionDstu3Test; import ca.uhn.fhir.jpa.subscription.module.BaseSubscriptionDstu3Test;
import ca.uhn.fhir.jpa.subscription.module.IPointcutLatch;
import ca.uhn.fhir.jpa.subscription.module.PointcutLatch; import ca.uhn.fhir.jpa.subscription.module.PointcutLatch;
import ca.uhn.fhir.jpa.subscription.module.ResourceModifiedMessage; import ca.uhn.fhir.jpa.subscription.module.ResourceModifiedMessage;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionChannelFactory; import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionChannelFactory;
@ -150,7 +151,7 @@ public abstract class BaseBlockingQueueSubscribableChannelDstu3Test extends Base
ourListenerServer.stop(); ourListenerServer.stop();
} }
public static class ObservationListener implements IResourceProvider { public static class ObservationListener implements IResourceProvider, IPointcutLatch {
private PointcutLatch updateLatch = new PointcutLatch("Observation Update"); private PointcutLatch updateLatch = new PointcutLatch("Observation Update");
@ -176,18 +177,17 @@ public abstract class BaseBlockingQueueSubscribableChannelDstu3Test extends Base
return new MethodOutcome(new IdType("Observation/1"), false); return new MethodOutcome(new IdType("Observation/1"), false);
} }
public void setExpectedCount(int count) throws InterruptedException { @Override
public void setExpectedCount(int count) {
updateLatch.setExpectedCount(count); updateLatch.setExpectedCount(count);
} }
public void awaitExpected() throws InterruptedException { @Override
updateLatch.awaitExpected(); public List<HookParams> awaitExpected() throws InterruptedException {
} return updateLatch.awaitExpected();
public void expectNothing() {
updateLatch.expectNothing();
} }
@Override
public void clear() { updateLatch.clear();} public void clear() { updateLatch.clear();}
} }
} }

View File

@ -75,7 +75,7 @@ public class SubscriptionCheckingSubscriberTest extends BaseBlockingQueueSubscri
ourObservationListener.setExpectedCount(0); ourObservationListener.setExpectedCount(0);
sendObservation(code, "SNOMED-CT"); sendObservation(code, "SNOMED-CT");
ourObservationListener.expectNothing(); ourObservationListener.clear();
assertEquals(0, ourContentTypes.size()); assertEquals(0, ourContentTypes.size());
} }

View File

@ -71,7 +71,7 @@ public class SubscriptionMatchingSubscriberTest extends BaseBlockingQueueSubscri
ourObservationListener.setExpectedCount(0); ourObservationListener.setExpectedCount(0);
sendObservation(code, "SNOMED-CT"); sendObservation(code, "SNOMED-CT");
ourObservationListener.expectNothing(); ourObservationListener.clear();
assertEquals(0, ourContentTypes.size()); assertEquals(0, ourContentTypes.size());
} }

View File

@ -0,0 +1,76 @@
package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
import ca.uhn.fhir.jpa.subscription.module.CanonicalSubscription;
import ca.uhn.fhir.jpa.subscription.module.CanonicalSubscriptionChannelType;
import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
import ca.uhn.fhir.jpa.subscription.module.cache.SubscriptionRegistry;
import org.hl7.fhir.r4.model.IdType;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringRunner;
import static org.junit.Assert.*;
import static org.mockito.Mockito.when;
@RunWith(SpringRunner.class)
@ContextConfiguration
public class WebsocketConnectionValidatorTest {
public static String RESTHOOK_SUBSCRIPTION_ID = "1";
public static String WEBSOCKET_SUBSCRIPTION_ID = "2";
public static String NON_EXISTENT_SUBSCRIPTION_ID = "3";
@Configuration
@ComponentScan("ca.uhn.fhir.jpa.subscription.module.subscriber.websocket")
public static class SpringConfig {
}
@MockBean
SubscriptionRegistry mySubscriptionRegistry;
@Autowired
WebsocketConnectionValidator myWebsocketConnectionValidator;
@Before
public void before() {
CanonicalSubscription resthookSubscription = new CanonicalSubscription();
resthookSubscription.setChannelType(CanonicalSubscriptionChannelType.RESTHOOK);
ActiveSubscription resthookActiveSubscription = new ActiveSubscription(resthookSubscription, null);
when(mySubscriptionRegistry.get(RESTHOOK_SUBSCRIPTION_ID)).thenReturn(resthookActiveSubscription);
CanonicalSubscription websocketSubscription = new CanonicalSubscription();
websocketSubscription.setChannelType(CanonicalSubscriptionChannelType.WEBSOCKET);
ActiveSubscription websocketActiveSubscription = new ActiveSubscription(websocketSubscription, null);
when(mySubscriptionRegistry.get(WEBSOCKET_SUBSCRIPTION_ID)).thenReturn(websocketActiveSubscription);
}
@Test
public void validateRequest() {
IdType idType;
WebsocketValidationResponse response;
idType = new IdType();
response = myWebsocketConnectionValidator.validate(idType);
assertFalse(response.isValid());
assertEquals("Invalid bind request - No ID included: null", response.getMessage());
idType = new IdType(NON_EXISTENT_SUBSCRIPTION_ID);
response = myWebsocketConnectionValidator.validate(idType);
assertFalse(response.isValid());
assertEquals("Invalid bind request - Unknown subscription: Subscription/" + NON_EXISTENT_SUBSCRIPTION_ID, response.getMessage());
idType = new IdType(RESTHOOK_SUBSCRIPTION_ID);
response = myWebsocketConnectionValidator.validate(idType);
assertFalse(response.isValid());
assertEquals("Subscription Subscription/" + RESTHOOK_SUBSCRIPTION_ID + " is not a WEBSOCKET subscription", response.getMessage());
idType = new IdType(WEBSOCKET_SUBSCRIPTION_ID);
response = myWebsocketConnectionValidator.validate(idType);
assertTrue(response.isValid());
}
}