ARTEMIS-2607 interceptor returns false but processing continues

This commit is contained in:
Justin Bertram 2020-01-31 14:21:20 -06:00 committed by Clebert Suconic
parent fa6a008fa9
commit a8cf6b04b4
13 changed files with 289 additions and 55 deletions

View File

@ -294,11 +294,11 @@ public class AMQPConnectionCallback implements FailureListener, CloseListener {
return null;
}
public void invokeIncomingInterceptors(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
manager.invokeIncoming(message, connection);
public String invokeIncomingInterceptors(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
return manager.invokeIncoming(message, connection);
}
public void invokeOutgoingInterceptors(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
manager.invokeOutgoing(message, connection);
public String invokeOutgoingInterceptors(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
return manager.invokeOutgoing(message, connection);
}
}

View File

@ -507,7 +507,7 @@ public class AMQPSessionCallback implements SessionCallback {
final Receiver receiver,
final RoutingContext routingContext) throws Exception {
message.setConnectionID(receiver.getSession().getConnection().getRemoteContainer());
invokeIncoming((AMQPMessage) message, (ActiveMQProtonRemotingConnection) transportConnection.getProtocolConnection());
if (invokeIncoming((AMQPMessage) message, (ActiveMQProtonRemotingConnection) transportConnection.getProtocolConnection()) == null) {
serverSession.send(transaction, message, directDeliver, false, routingContext);
afterIO(new IOCallback() {
@ -537,6 +537,9 @@ public class AMQPSessionCallback implements SessionCallback {
});
}
});
} else {
rejectMessage(delivery, Symbol.valueOf("failed"), "Interceptor rejected message");
}
}
/** Will execute a Runnable on an Address when there's space in memory*/
@ -692,12 +695,12 @@ public class AMQPSessionCallback implements SessionCallback {
manager.getServer().getSecurityStore().check(address, checkType, session);
}
public void invokeIncoming(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
protonSPI.invokeIncomingInterceptors(message, connection);
public String invokeIncoming(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
return protonSPI.invokeIncomingInterceptors(message, connection);
}
public void invokeOutgoing(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
protonSPI.invokeOutgoingInterceptors(message, connection);
public String invokeOutgoing(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
return protonSPI.invokeOutgoingInterceptors(message, connection);
}
public void addProducer(ServerProducer serverProducer) {

View File

@ -294,12 +294,12 @@ public class ProtonProtocolManager extends AbstractProtocolManager<AMQPMessage,
return prefixes;
}
public void invokeIncoming(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
super.invokeInterceptors(this.incomingInterceptors, message, connection);
public String invokeIncoming(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
return super.invokeInterceptors(this.incomingInterceptors, message, connection);
}
public void invokeOutgoing(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
super.invokeInterceptors(this.outgoingInterceptors, message, connection);
public String invokeOutgoing(AMQPMessage message, ActiveMQProtonRemotingConnection connection) {
return super.invokeInterceptors(this.outgoingInterceptors, message, connection);
}
public int getInitialRemoteMaxFrameSize() {

View File

@ -787,7 +787,9 @@ public class ProtonServerSenderContext extends ProtonInitializable implements Pr
}
AMQPMessage message = CoreAmqpConverter.checkAMQP(messageReference.getMessage());
sessionSPI.invokeOutgoing(message, (ActiveMQProtonRemotingConnection) sessionSPI.getTransportConnection().getProtocolConnection());
if (sessionSPI.invokeOutgoing(message, (ActiveMQProtonRemotingConnection) sessionSPI.getTransportConnection().getProtocolConnection()) != null) {
return;
}
// Let the Message decide how to present the message bytes
ReadableBuffer sendBuffer = message.getSendBuffer(messageReference.getDeliveryCount());

View File

@ -100,7 +100,11 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter {
MQTTUtil.logMessage(session.getState(), message, true);
this.protocolManager.invokeIncoming(message, this.connection);
if (this.protocolManager.invokeIncoming(message, this.connection) != null) {
log.debugf("Interceptor rejected MQTT message: %s", message);
disconnect(true);
return;
}
switch (message.fixedHeader().messageType()) {
case CONNECT:
@ -246,8 +250,10 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter {
}
private void sendToClient(MqttMessage message) {
if (this.protocolManager.invokeOutgoing(message, connection) != null) {
return;
}
MQTTUtil.logMessage(session.getSessionState(), message, false);
this.protocolManager.invokeOutgoing(message, connection);
ctx.write(message);
ctx.flush();
}

View File

@ -209,12 +209,12 @@ public class MQTTProtocolManager extends AbstractProtocolManager<MqttMessage, MQ
return websocketRegistryNames;
}
public void invokeIncoming(MqttMessage mqttMessage, MQTTConnection connection) {
super.invokeInterceptors(this.incomingInterceptors, mqttMessage, connection);
public String invokeIncoming(MqttMessage mqttMessage, MQTTConnection connection) {
return super.invokeInterceptors(this.incomingInterceptors, mqttMessage, connection);
}
public void invokeOutgoing(MqttMessage mqttMessage, MQTTConnection connection) {
super.invokeInterceptors(this.outgoingInterceptors, mqttMessage, connection);
public String invokeOutgoing(MqttMessage mqttMessage, MQTTConnection connection) {
return super.invokeInterceptors(this.outgoingInterceptors, mqttMessage, connection);
}
public boolean isClientConnected(String clientId, MQTTConnection connection) {

View File

@ -154,8 +154,10 @@ public class StompProtocolManager extends AbstractProtocolManager<StompFrame, St
}
try {
invokeInterceptors(this.incomingInterceptors, request, conn);
conn.logFrame(request, true);
if (invokeInterceptors(this.incomingInterceptors, request, conn) != null) {
return;
}
conn.handleFrame(request);
} finally {
server.getStorageManager().clearContext();
@ -187,7 +189,9 @@ public class StompProtocolManager extends AbstractProtocolManager<StompFrame, St
// Public --------------------------------------------------------
public boolean send(final StompConnection connection, final StompFrame frame) {
invokeInterceptors(this.outgoingInterceptors, frame, connection);
if (invokeInterceptors(this.outgoingInterceptors, frame, connection) != null) {
return false;
}
connection.logFrame(frame, false);
synchronized (connection) {

View File

@ -1967,7 +1967,7 @@ public interface ActiveMQServerLogger extends BasicLogger {
@LogMessage(level = Logger.Level.ERROR)
@Message(id = 224082, value = "Failed to invoke an interceptor", format = Message.Format.MESSAGE_FORMAT)
void failedToInvokeAninterceptor(@Cause Exception e);
void failedToInvokeAnInterceptor(@Cause Exception e);
@LogMessage(level = Logger.Level.ERROR)
@Message(id = 224083, value = "Failed to close context", format = Message.Format.MESSAGE_FORMAT)

View File

@ -32,18 +32,20 @@ public abstract class AbstractProtocolManager<P, I extends BaseInterceptor<P>, C
private final Map<SimpleString, RoutingType> prefixes = new HashMap<>();
protected void invokeInterceptors(final List<I> interceptors, final P message, final C connection) {
protected String invokeInterceptors(final List<I> interceptors, final P message, final C connection) {
if (interceptors != null && !interceptors.isEmpty()) {
for (I interceptor : interceptors) {
try {
if (!interceptor.intercept(message, connection)) {
break;
return interceptor.getClass().getName();
}
} catch (Exception e) {
ActiveMQServerLogger.LOGGER.failedToInvokeAninterceptor(e);
ActiveMQServerLogger.LOGGER.failedToInvokeAnInterceptor(e);
}
}
}
return null;
}
@Override

View File

@ -80,6 +80,73 @@ public class AmqpSendReceiveInterceptorTest extends AmqpClientTestSupport {
connection.close();
}
@Test(timeout = 60000)
public void testRejectMessageWithIncomingInterceptor() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
server.getRemotingService().addIncomingInterceptor(new AmqpInterceptor() {
@Override
public boolean intercept(AMQPMessage message, RemotingConnection connection) throws ActiveMQException {
latch.countDown();
return false;
}
});
AmqpClient client = createAmqpClient();
AmqpConnection connection = addConnection(client.connect());
AmqpSession session = connection.createSession();
AmqpSender sender = session.createSender(getTestName());
AmqpMessage message = new AmqpMessage();
message.setMessageId("msg" + 1);
message.setText("Test-Message");
try {
sender.send(message);
fail("Sending message should have thrown exception here.");
} catch (Exception e) {
assertEquals("Interceptor rejected message [condition = failed]", e.getMessage());
}
assertTrue(latch.await(5, TimeUnit.SECONDS));
AmqpReceiver receiver = session.createReceiver(getTestName());
receiver.flow(2);
AmqpMessage amqpMessage = receiver.receive(5, TimeUnit.SECONDS);
assertNull(amqpMessage);
sender.close();
receiver.close();
connection.close();
}
@Test(timeout = 60000)
public void testRejectMessageWithOutgoingInterceptor() throws Exception {
AmqpClient client = createAmqpClient();
AmqpConnection connection = addConnection(client.connect());
AmqpSession session = connection.createSession();
AmqpSender sender = session.createSender(getTestName());
AmqpMessage message = new AmqpMessage();
message.setMessageId("msg" + 1);
message.setText("Test-Message");
sender.send(message);
final CountDownLatch latch = new CountDownLatch(1);
server.getRemotingService().addOutgoingInterceptor(new AmqpInterceptor() {
@Override
public boolean intercept(AMQPMessage packet, RemotingConnection connection) throws ActiveMQException {
latch.countDown();
return false;
}
});
AmqpReceiver receiver = session.createReceiver(getTestName());
receiver.flow(2);
AmqpMessage amqpMessage = receiver.receive(5, TimeUnit.SECONDS);
assertNull(amqpMessage);
assertEquals(latch.getCount(), 0);
sender.close();
receiver.close();
connection.close();
}
private static final String ADDRESS = "address";
private static final String MESSAGE_ID = "messageId";
private static final String CORRELATION_ID = "correlationId";

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.tests.integration.mqtt.imported;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttPublishMessage;
import org.apache.activemq.artemis.api.core.ActiveMQException;
import org.apache.activemq.artemis.core.protocol.mqtt.MQTTInterceptor;
import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ErrorCollector;
public class MQTTRejectingInterceptorTest extends MQTTTestSupport {
@Rule
public ErrorCollector collector = new ErrorCollector();
@Test(timeout = 60000)
public void testRejectedMQTTMessage() throws Exception {
final String addressQueue = name.getMethodName();
final String msgText = "Test rejected message";
final MQTTClientProvider subscribeProvider = getMQTTClientProvider();
initializeConnection(subscribeProvider);
subscribeProvider.subscribe(addressQueue, AT_MOST_ONCE);
MQTTInterceptor incomingInterceptor = new MQTTInterceptor() {
@Override
public boolean intercept(MqttMessage packet, RemotingConnection connection) throws ActiveMQException {
System.out.println("incoming");
if (packet.getClass() == MqttPublishMessage.class) {
return false;
} else {
return true;
}
}
};
server.getRemotingService().addIncomingInterceptor(incomingInterceptor);
final MQTTClientProvider publishProvider = getMQTTClientProvider();
initializeConnection(publishProvider);
publishProvider.publish(addressQueue, msgText.getBytes(), AT_MOST_ONCE, false);
assertNull(subscribeProvider.receive(3000));
subscribeProvider.disconnect();
publishProvider.disconnect();
}
}

View File

@ -38,8 +38,8 @@ public class StompWithInterceptorsTest extends StompTestBase {
@Override
public List<String> getIncomingInterceptors() {
List<String> stompIncomingInterceptor = new ArrayList<>();
stompIncomingInterceptor.add("org.apache.activemq.artemis.tests.integration.stomp.StompWithInterceptorsTest$IncomingStompInterceptor");
stompIncomingInterceptor.add("org.apache.activemq.artemis.tests.integration.stomp.StompWithInterceptorsTest$CoreInterceptor");
stompIncomingInterceptor.add(IncomingStompInterceptor.class.getName());
stompIncomingInterceptor.add(CoreInterceptor.class.getName());
return stompIncomingInterceptor;
}
@ -47,7 +47,7 @@ public class StompWithInterceptorsTest extends StompTestBase {
@Override
public List<String> getOutgoingInterceptors() {
List<String> stompOutgoingInterceptor = new ArrayList<>();
stompOutgoingInterceptor.add("org.apache.activemq.artemis.tests.integration.stomp.StompWithInterceptorsTest$OutgoingStompInterceptor");
stompOutgoingInterceptor.add(OutgoingStompInterceptor.class.getName());
return stompOutgoingInterceptor;
}

View File

@ -0,0 +1,86 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.tests.integration.stomp;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.activemq.artemis.api.core.SimpleString;
import org.apache.activemq.artemis.core.protocol.stomp.Stomp;
import org.apache.activemq.artemis.core.protocol.stomp.StompFrame;
import org.apache.activemq.artemis.core.protocol.stomp.StompFrameInterceptor;
import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection;
import org.apache.activemq.artemis.tests.integration.stomp.util.ClientStompFrame;
import org.apache.activemq.artemis.tests.integration.stomp.util.StompClientConnection;
import org.apache.activemq.artemis.tests.integration.stomp.util.StompClientConnectionFactory;
import org.apache.activemq.artemis.tests.util.Wait;
import org.junit.Assert;
import org.junit.Test;
public class StompWithRejectingInterceptorTest extends StompTestBase {
@Override
public List<String> getIncomingInterceptors() {
List<String> stompIncomingInterceptor = new ArrayList<>();
stompIncomingInterceptor.add(IncomingStompFrameRejectInterceptor.class.getName());
return stompIncomingInterceptor;
}
@Test
public void stompFrameInterceptor() throws Exception {
IncomingStompFrameRejectInterceptor.interceptedFrames.clear();
StompClientConnection conn = StompClientConnectionFactory.createClientConnection(uri);
conn.connect(defUser, defPass);
ClientStompFrame frame = conn.createFrame("SEND");
frame.addHeader("destination", getQueuePrefix() + getQueueName());
frame.setBody("Hello World");
conn.sendFrame(frame);
conn.disconnect();
assertTrue(Wait.waitFor(() -> IncomingStompFrameRejectInterceptor.interceptedFrames.size() == 3, 2000, 50));
List<String> incomingCommands = new ArrayList<>(4);
incomingCommands.add("CONNECT");
incomingCommands.add("SEND");
incomingCommands.add("DISCONNECT");
for (int i = 0; i < IncomingStompFrameRejectInterceptor.interceptedFrames.size(); i++) {
Assert.assertEquals(incomingCommands.get(i), IncomingStompFrameRejectInterceptor.interceptedFrames.get(i).getCommand());
}
Wait.assertFalse(() -> server.locateQueue(SimpleString.toSimpleString(getQueuePrefix() + getQueueName())).getMessageCount() > 0, 1000, 100);
}
public static class IncomingStompFrameRejectInterceptor implements StompFrameInterceptor {
static List<StompFrame> interceptedFrames = Collections.synchronizedList(new ArrayList<>());
@Override
public boolean intercept(StompFrame stompFrame, RemotingConnection connection) {
interceptedFrames.add(stompFrame);
if (stompFrame.getCommand() == Stomp.Commands.SEND) {
return false;
}
return true;
}
}
}