NIFI-11144 Fix failing tests for ConsumeJMS/PublishJMS

This closes #6930.

Signed-off-by: Tamas Palfy <tpalfy@apache.org>
This commit is contained in:
Nandor Soma Abonyi 2023-02-06 19:07:02 +01:00 committed by Tamas Palfy
parent 03625bb679
commit 9ee34eeb01
10 changed files with 191 additions and 67 deletions

View File

@ -206,8 +206,7 @@ public class StandardProcessorTestRunner implements TestRunner {
try {
ReflectionUtils.invokeMethodsWithAnnotation(OnScheduled.class, processor, context);
} catch (final Exception e) {
e.printStackTrace();
Assertions.fail("Could not invoke methods annotated with @OnScheduled annotation due to: " + e);
Assertions.fail("Could not invoke methods annotated with @OnScheduled annotation due to: " + e, e);
}
}

View File

@ -185,14 +185,12 @@ public abstract class AbstractJMSProcessor<T extends JMSWorker> extends Abstract
} catch (Exception e) {
getLogger().error("Failed to initialize JMS Connection Factory", e);
context.yield();
return;
throw e;
}
}
try {
rendezvousWithJms(context, session, worker);
} catch (Exception e) {
getLogger().error("Error while trying to process JMS message", e);
} finally {
//in case of exception during worker's connection (consumer or publisher),
//an appropriate service is responsible to invalidate the worker.
@ -209,7 +207,7 @@ public abstract class AbstractJMSProcessor<T extends JMSWorker> extends Abstract
CachingConnectionFactory currentCF = (CachingConnectionFactory)worker.jmsTemplate.getConnectionFactory();
connectionFactoryProvider.resetConnectionFactory(currentCF.getTargetConnectionFactory());
worker = buildTargetResource(context);
}catch(Exception e) {
} catch (Exception e) {
getLogger().error("Failed to rebuild: " + connectionFactoryProvider);
worker = null;
}

View File

@ -298,9 +298,10 @@ public class ConsumeJMS extends AbstractJMSProcessor<JMSConsumer> {
}
});
} catch(Exception e) {
getLogger().error("Error while trying to process JMS message", e);
consumer.setValid(false);
context.yield();
throw e; // for backward compatibility with exception handling in flows
throw e;
}
}

View File

@ -46,7 +46,7 @@ import java.util.Map;
/**
* Generic consumer of messages from JMS compliant messaging system.
*/
final class JMSConsumer extends JMSWorker {
class JMSConsumer extends JMSWorker {
JMSConsumer(CachingConnectionFactory connectionFactory, JmsTemplate jmsTemplate, ComponentLog logger) {
super(connectionFactory, jmsTemplate, logger);
@ -83,6 +83,9 @@ final class JMSConsumer extends JMSWorker {
}
/**
* Receives a message from the broker. It is the consumerCallback's responsibility to acknowledge the received message.
*/
public void consume(final String destinationName, String errorQueueName, final boolean durable, final boolean shared, final String subscriptionName, final String messageSelector,
final String charset, final ConsumerCallback consumerCallback) {
this.jmsTemplate.execute(new SessionCallback<Void>() {

View File

@ -222,7 +222,7 @@ public class PublishJMS extends AbstractJMSProcessor<JMSPublisher> {
processSession.getProvenanceReporter().send(flowFile, destinationName);
} catch (Exception e) {
processSession.transfer(flowFile, REL_FAILURE);
this.getLogger().error("Failed while sending message to JMS via " + publisher, e);
getLogger().error("Failed while sending message to JMS via " + publisher, e);
context.yield();
publisher.setValid(false);
}

View File

@ -23,6 +23,7 @@ import org.apache.activemq.command.ActiveMQMessage;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.apache.activemq.transport.tcp.TcpTransportFactory;
import org.apache.activemq.wireformat.WireFormat;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.jms.cf.JMSConnectionFactoryProperties;
import org.apache.nifi.jms.cf.JMSConnectionFactoryProvider;
import org.apache.nifi.jms.cf.JMSConnectionFactoryProviderDefinition;
@ -30,6 +31,7 @@ import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.io.OutputStreamCallback;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.MockProcessContext;
@ -62,12 +64,15 @@ import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.apache.nifi.jms.processors.helpers.AssertionUtils.assertCausedBy;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
public class ConsumeJMSIT {
@ -196,9 +201,7 @@ public class ConsumeJMSIT {
try {
ActiveMQConnectionFactory cf = new ActiveMQConnectionFactory("vm://localhost?broker.persistent=false");
JMSPublisher sender = new JMSPublisher((CachingConnectionFactory) jmsTemplate.getConnectionFactory(), jmsTemplate, mock(ComponentLog.class));
sender.jmsTemplate.send("testMapMessage", __ -> createUnsupportedMessage("unsupportedMessagePropertyKey", "unsupportedMessagePropertyValue"));
jmsTemplate.send("testMapMessage", __ -> createUnsupportedMessage("unsupportedMessagePropertyKey", "unsupportedMessagePropertyValue"));
TestRunner runner = TestRunners.newTestRunner(new ConsumeJMS());
JMSConnectionFactoryProviderDefinition cs = mock(JMSConnectionFactoryProviderDefinition.class);
@ -226,9 +229,7 @@ public class ConsumeJMSIT {
private void testMessageTypeAttribute(String destinationName, final MessageCreator messageCreator, String expectedJmsMessageTypeAttribute) throws Exception {
JmsTemplate jmsTemplate = CommonTest.buildJmsTemplateForDestination(false);
try {
JMSPublisher sender = new JMSPublisher((CachingConnectionFactory) jmsTemplate.getConnectionFactory(), jmsTemplate, mock(ComponentLog.class));
sender.jmsTemplate.send(destinationName, messageCreator);
jmsTemplate.send(destinationName, messageCreator);
TestRunner runner = TestRunners.newTestRunner(new ConsumeJMS());
JMSConnectionFactoryProviderDefinition cs = mock(JMSConnectionFactoryProviderDefinition.class);
@ -324,7 +325,8 @@ public class ConsumeJMSIT {
TestRunner runner = createNonSharedDurableConsumer(cf, destinationName);
runner.setThreadCount(2);
final TestRunner temp = runner;
assertThrows(Throwable.class, () -> temp.run(1, true));
assertCausedBy(ProcessException.class, "Durable non shared subscriptions cannot work on multiple threads.", () -> temp.run(1, true));
runner = createNonSharedDurableConsumer(cf, destinationName);
// using one thread, it should not fail.
@ -334,7 +336,7 @@ public class ConsumeJMSIT {
/**
* <p>
* This test validates the connection resources are closed if the publisher is marked as invalid.
* This test validates the connection resources are closed if the consumer is marked as invalid.
* </p>
* <p>
* This tests validates the proper resources handling for TCP connections using ActiveMQ (the bug was discovered against ActiveMQ 5.x). In this test, using some ActiveMQ's classes is possible to
@ -356,7 +358,7 @@ public class ConsumeJMSIT {
BrokerService broker = new BrokerService();
try {
broker.setPersistent(false);
broker.setBrokerName("nifi7034publisher");
broker.setBrokerName("nifi7034consumer");
TransportConnector connector = broker.addConnector("tcp://127.0.0.1:0");
int port = connector.getServer().getSocketAddress().getPort();
broker.start();
@ -384,7 +386,9 @@ public class ConsumeJMSIT {
runner.setProperty(ConsumeJMS.DESTINATION, destinationName);
runner.setProperty(ConsumeJMS.DESTINATION_TYPE, ConsumeJMS.TOPIC);
assertThrows(AssertionError.class, () -> runner.run());
runner.run();
// since the worker is marked to invalid, we don't need to expect an exception here, because the worker recreation is handled automatically
assertFalse(tcpTransport.get().isConnected(), "It is expected transport be closed. ");
} finally {
if (broker != null) {
@ -409,18 +413,21 @@ public class ConsumeJMSIT {
runner.setProperty(ConsumeJMS.DESTINATION, "foo");
runner.setProperty(ConsumeJMS.DESTINATION_TYPE, ConsumeJMS.TOPIC);
assertThrows(AssertionError.class, () -> runner.run());
assertTrue(((MockProcessContext) runner.getProcessContext()).isYieldCalled(), "In case of an exception, the processor should be yielded.");
assertCausedBy(UnknownHostException.class, runner::run);
assertTrue(((MockProcessContext) runner.getProcessContext()).isYieldCalled(), "In case of an exception, the processor should be yielded.");
}
@Test
public void whenExceptionIsRaisedDuringConnectionFactoryInitializationTheProcessorShouldBeYielded() throws Exception {
final String nonExistentClassName = "DummyJMSConnectionFactoryClass";
TestRunner runner = TestRunners.newTestRunner(ConsumeJMS.class);
// using (non-JNDI) JMS Connection Factory via controller service
JMSConnectionFactoryProvider cfProvider = new JMSConnectionFactoryProvider();
runner.addControllerService("cfProvider", cfProvider);
runner.setProperty(cfProvider, JMSConnectionFactoryProperties.JMS_CONNECTION_FACTORY_IMPL, "DummyJMSConnectionFactoryClass");
runner.setProperty(cfProvider, JMSConnectionFactoryProperties.JMS_CONNECTION_FACTORY_IMPL, nonExistentClassName);
runner.setProperty(cfProvider, JMSConnectionFactoryProperties.JMS_BROKER_URI, "DummyBrokerUri");
runner.enableControllerService(cfProvider);
@ -428,10 +435,49 @@ public class ConsumeJMSIT {
runner.setProperty(ConsumeJMS.DESTINATION, "myTopic");
runner.setProperty(ConsumeJMS.DESTINATION_TYPE, ConsumeJMS.TOPIC);
assertThrows(AssertionError.class, () -> runner.run());
assertCausedBy(ClassNotFoundException.class, nonExistentClassName, runner::run);
assertTrue(((MockProcessContext) runner.getProcessContext()).isYieldCalled(), "In case of an exception, the processor should be yielded.");
}
@Test
@Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
public void whenExceptionIsRaisedInAcceptTheProcessorShouldYieldAndRollback() throws Exception {
final String destination = "testQueue";
final RuntimeException expectedException = new RuntimeException();
final ConsumeJMS processor = new ConsumeJMS() {
@Override
protected void rendezvousWithJms(ProcessContext context, ProcessSession processSession, JMSConsumer consumer) throws ProcessException {
ProcessSession spiedSession = spy(processSession);
doThrow(expectedException).when(spiedSession).write(any(FlowFile.class), any(OutputStreamCallback.class));
super.rendezvousWithJms(context, spiedSession, consumer);
}
};
JmsTemplate jmsTemplate = CommonTest.buildJmsTemplateForDestination(false);
try {
jmsTemplate.send(destination, session -> session.createTextMessage("msg"));
TestRunner runner = TestRunners.newTestRunner(processor);
JMSConnectionFactoryProviderDefinition cs = mock(JMSConnectionFactoryProviderDefinition.class);
when(cs.getIdentifier()).thenReturn("cfProvider");
when(cs.getConnectionFactory()).thenReturn(jmsTemplate.getConnectionFactory());
runner.addControllerService("cfProvider", cs);
runner.enableControllerService(cs);
runner.setProperty(PublishJMS.CF_SERVICE, "cfProvider");
runner.setProperty(ConsumeJMS.DESTINATION, destination);
runner.setProperty(ConsumeJMS.DESTINATION_TYPE, ConsumeJMS.QUEUE);
assertCausedBy(expectedException, () -> runner.run(1, false));
assertTrue(((MockProcessContext) runner.getProcessContext()).isYieldCalled(), "In case of an exception, the processor should be yielded.");
} finally {
((CachingConnectionFactory) jmsTemplate.getConnectionFactory()).destroy();
}
}
private static void publishAMessage(ActiveMQConnectionFactory cf, final String destinationName, String messageContent) throws JMSException {
// Publish a message.
try (Connection conn = cf.createConnection();

View File

@ -18,7 +18,6 @@ package org.apache.nifi.jms.processors;
import org.apache.activemq.ActiveMQConnectionFactory;
import org.apache.activemq.command.ActiveMQMessage;
import org.apache.nifi.logging.ComponentLog;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.jms.connection.CachingConnectionFactory;
@ -33,12 +32,10 @@ import javax.jms.Session;
import javax.jms.StreamMessage;
import javax.jms.TextMessage;
import static org.mockito.Mockito.mock;
@Disabled("Used for manual testing.")
public class ConsumeJMSManualTest {
@Test
public void testTextMessage() throws Exception {
public void testTextMessage() {
MessageCreator messageCreator = session -> {
TextMessage message = session.createTextMessage("textMessageContent");
@ -49,7 +46,7 @@ public class ConsumeJMSManualTest {
}
@Test
public void testBytesMessage() throws Exception {
public void testBytesMessage() {
MessageCreator messageCreator = session -> {
BytesMessage message = session.createBytesMessage();
@ -62,7 +59,7 @@ public class ConsumeJMSManualTest {
}
@Test
public void testObjectMessage() throws Exception {
public void testObjectMessage() {
MessageCreator messageCreator = session -> {
ObjectMessage message = session.createObjectMessage();
@ -75,7 +72,7 @@ public class ConsumeJMSManualTest {
}
@Test
public void testStreamMessage() throws Exception {
public void testStreamMessage() {
MessageCreator messageCreator = session -> {
StreamMessage message = session.createStreamMessage();
@ -98,7 +95,7 @@ public class ConsumeJMSManualTest {
}
@Test
public void testMapMessage() throws Exception {
public void testMapMessage() {
MessageCreator messageCreator = session -> {
MapMessage message = session.createMapMessage();
@ -121,14 +118,14 @@ public class ConsumeJMSManualTest {
}
@Test
public void testUnsupportedMessage() throws Exception {
public void testUnsupportedMessage() {
MessageCreator messageCreator = session -> new ActiveMQMessage();
send(messageCreator);
}
private void send(MessageCreator messageCreator) throws Exception {
final String destinationName = "TEST";
private void send(MessageCreator messageCreator) {
final String destinationName = "TEST";
ConnectionFactory activeMqConnectionFactory = new ActiveMQConnectionFactory("tcp://localhost:61616");
final ConnectionFactory connectionFactory = new CachingConnectionFactory(activeMqConnectionFactory);
@ -139,9 +136,7 @@ public class ConsumeJMSManualTest {
jmsTemplate.setReceiveTimeout(10L);
try {
JMSPublisher sender = new JMSPublisher((CachingConnectionFactory) jmsTemplate.getConnectionFactory(), jmsTemplate, mock(ComponentLog.class));
sender.jmsTemplate.send(destinationName, messageCreator);
jmsTemplate.send(destinationName, messageCreator);
} finally {
((CachingConnectionFactory) jmsTemplate.getConnectionFactory()).destroy();
}

View File

@ -58,7 +58,7 @@ import org.springframework.jms.support.JmsHeaders;
public class JMSPublisherConsumerIT {
@Test
public void testObjectMessage() throws Exception {
public void testObjectMessage() {
final String destinationName = "testObjectMessage";
MessageCreator messageCreator = session -> {
@ -136,7 +136,7 @@ public class JMSPublisherConsumerIT {
}
@Test
public void testMapMessage() throws Exception {
public void testMapMessage() {
final String destinationName = "testObjectMessage";
MessageCreator messageCreator = session -> {
@ -269,7 +269,7 @@ public class JMSPublisherConsumerIT {
* at which point this test will no be longer required.
*/
@Test
public void validateFailOnUnsupportedMessageType() throws Exception {
public void validateFailOnUnsupportedMessageType() {
final String destinationName = "validateFailOnUnsupportedMessageType";
JmsTemplate jmsTemplate = CommonTest.buildJmsTemplateForDestination(false);
@ -332,13 +332,17 @@ public class JMSPublisherConsumerIT {
@Test
@Timeout(value = 20000, unit = TimeUnit.MILLISECONDS)
public void testMultipleThreads() throws Exception {
final int threadCount = 4;
final int totalMessageCount = 1000;
final int messagesPerThreadCount = totalMessageCount / threadCount;
String destinationName = "testMultipleThreads";
JmsTemplate publishTemplate = CommonTest.buildJmsTemplateForDestination(false);
final CountDownLatch consumerTemplateCloseCount = new CountDownLatch(4);
final CountDownLatch consumerTemplateCloseCount = new CountDownLatch(threadCount);
try {
JMSPublisher publisher = new JMSPublisher((CachingConnectionFactory) publishTemplate.getConnectionFactory(), publishTemplate, mock(ComponentLog.class));
for (int i = 0; i < 4000; i++) {
for (int i = 0; i < totalMessageCount; i++) {
publisher.publish(destinationName, String.valueOf(i).getBytes(StandardCharsets.UTF_8));
}
@ -359,7 +363,7 @@ public class JMSPublisherConsumerIT {
try {
JMSConsumer consumer = new JMSConsumer((CachingConnectionFactory) consumeTemplate.getConnectionFactory(), consumeTemplate, mock(ComponentLog.class));
for (int j = 0; j < 1000 && msgCount.get() < 4000; j++) {
for (int j = 0; j < messagesPerThreadCount && msgCount.get() < totalMessageCount; j++) {
consumer.consume(destinationName, null, false, false, null, null, "UTF-8", callback);
}
} finally {
@ -373,7 +377,7 @@ public class JMSPublisherConsumerIT {
}
int iterations = 0;
while (msgCount.get() < 4000) {
while (msgCount.get() < totalMessageCount) {
Thread.sleep(10L);
if (++iterations % 100 == 0) {
System.out.println(msgCount.get() + " messages received so far");
@ -389,7 +393,7 @@ public class JMSPublisherConsumerIT {
@Test
@Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
public void validateMessageRedeliveryWhenNotAcked() throws Exception {
public void validateMessageRedeliveryWhenNotAcked() {
String destinationName = "validateMessageRedeliveryWhenNotAcked";
JmsTemplate jmsTemplate = CommonTest.buildJmsTemplateForDestination(false);
try {
@ -426,6 +430,7 @@ public class JMSPublisherConsumerIT {
callbackInvoked.set(true);
assertEquals("1", new String(response.getMessageBody()));
acknowledge(response);
}
});
}
@ -467,6 +472,7 @@ public class JMSPublisherConsumerIT {
callbackInvoked.set(true);
assertEquals("2", new String(response.getMessageBody()));
acknowledge(response);
}
});
}
@ -478,6 +484,14 @@ public class JMSPublisherConsumerIT {
}
}
private void acknowledge(JMSResponse response) {
try {
response.acknowledge();
} catch (JMSException e) {
throw new IllegalStateException("Unable to acknowledge JMS message");
}
}
@Test
public void testMessageSelector() {
String destinationName = "testMessageSelector";

View File

@ -16,15 +16,6 @@
*/
package org.apache.nifi.jms.processors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.apache.activemq.ActiveMQConnectionFactory;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.TransportConnector;
@ -41,12 +32,17 @@ import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.MockProcessContext;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.springframework.jms.core.JmsTemplate;
import org.springframework.jms.support.JmsHeaders;
import javax.jms.BytesMessage;
import javax.jms.ConnectionFactory;
import javax.jms.Message;
import javax.jms.Queue;
import javax.jms.TextMessage;
import javax.net.SocketFactory;
import java.io.IOException;
import java.lang.reflect.Proxy;
import java.net.URI;
@ -56,12 +52,14 @@ import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.jms.BytesMessage;
import javax.jms.ConnectionFactory;
import javax.jms.Message;
import javax.jms.Queue;
import javax.jms.TextMessage;
import javax.net.SocketFactory;
import static org.apache.nifi.jms.processors.helpers.AssertionUtils.assertCausedBy;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class PublishJMSIT {
@ -525,19 +523,23 @@ public class PublishJMSIT {
}
@Test
public void whenExceptionIsRaisedDuringConnectionFactoryInitializationTheProcessorShouldBeYielded() throws Exception {
public void whenExceptionIsRaisedDuringConnectionFactoryInitializationTheProcessorShouldBeYielded() {
final String nonExistentClassName = "DummyInitialContextFactoryClass";
TestRunner runner = TestRunners.newTestRunner(PublishJMS.class);
// using JNDI JMS Connection Factory configured locally on the processor
runner.setProperty(JndiJmsConnectionFactoryProperties.JNDI_INITIAL_CONTEXT_FACTORY, "DummyInitialContextFactoryClass");
runner.setProperty(JndiJmsConnectionFactoryProperties.JNDI_INITIAL_CONTEXT_FACTORY, nonExistentClassName);
runner.setProperty(JndiJmsConnectionFactoryProperties.JNDI_PROVIDER_URL, "DummyProviderUrl");
runner.setProperty(JndiJmsConnectionFactoryProperties.JNDI_CONNECTION_FACTORY_NAME, "DummyConnectionFactoryName");
runner.setProperty(ConsumeJMS.DESTINATION, "myTopic");
runner.setProperty(ConsumeJMS.DESTINATION_TYPE, ConsumeJMS.TOPIC);
runner.setProperty(AbstractJMSProcessor.DESTINATION, "myTopic");
runner.setProperty(AbstractJMSProcessor.DESTINATION_TYPE, AbstractJMSProcessor.TOPIC);
runner.enqueue("message");
assertThrows(AssertionError.class, () -> runner.run());
assertCausedBy(ClassNotFoundException.class, nonExistentClassName, runner::run);
assertTrue(((MockProcessContext) runner.getProcessContext()).isYieldCalled(), "In case of an exception, the processor should be yielded.");
}
}

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.jms.processors.helpers;
import org.apache.commons.lang3.exception.ExceptionUtils;
import java.util.List;
import static org.junit.jupiter.api.Assertions.fail;
public class AssertionUtils {
public static <T extends Throwable> void assertCausedBy(Class<T> expectedType, Runnable runnable) {
assertCausedBy(expectedType, null, runnable);
}
public static <T extends Throwable> void assertCausedBy(Class<T> expectedType, String expectedMessage, Runnable runnable) {
try {
runnable.run();
fail(String.format("Expected an exception to be thrown with a cause of %s, but nothing was thrown.", expectedType.getCanonicalName()));
} catch (Throwable throwable) {
final List<Throwable> causes = ExceptionUtils.getThrowableList(throwable);
for (Throwable cause : causes) {
if (expectedType.isInstance(cause)) {
if (expectedMessage != null) {
if (cause.getMessage() != null && cause.getMessage().startsWith(expectedMessage)) {
return;
}
} else {
return;
}
}
}
fail(String.format("Exception is thrown but not found %s as a cause. Received exception is: %s", expectedType.getCanonicalName(), throwable), throwable);
}
}
public static void assertCausedBy(Throwable expectedException, Runnable runnable) {
try {
runnable.run();
fail(String.format("Expected an exception to be thrown with a cause of %s, but nothing was thrown.", expectedException));
} catch (Throwable throwable) {
final List<Throwable> causes = ExceptionUtils.getThrowableList(throwable);
for (Throwable cause : causes) {
if (cause.equals(expectedException)) {
return;
}
}
fail(String.format("Exception is thrown but not found %s as a cause. Received exception is: %s", expectedException, throwable), throwable);
}
}
}