ARTEMIS-2458 Fix AMQP Transaction Rollback Ordering by using a sorted add

This commit is contained in:
Clebert Suconic 2019-08-20 17:05:33 -04:00
parent 6fc11338e6
commit 61eb379741
17 changed files with 343 additions and 16 deletions

View File

@ -17,6 +17,7 @@
package org.apache.activemq.artemis.utils.collections;
import java.lang.reflect.Array;
import java.util.Comparator;
import java.util.NoSuchElementException;
import java.util.Objects;
@ -43,8 +44,15 @@ public class LinkedListImpl<E> implements LinkedList<E> {
private int nextIndex;
private final Comparator<E> comparator;
public LinkedListImpl() {
this(null);
}
public LinkedListImpl(Comparator<E> comparator) {
iters = createIteratorArray(INITIAL_ITERATOR_ARRAY_SIZE);
this.comparator = comparator;
}
@Override
@ -84,6 +92,60 @@ public class LinkedListImpl<E> implements LinkedList<E> {
}
}
public void addSorted(E e) {
if (comparator == null) {
throw new NullPointerException("comparator=null");
}
if (size == 0) {
addHead(e);
} else {
if (comparator.compare(head.next.val(), e) < 0) {
addHead(e);
return;
}
// in our usage, most of the times we will just add to the end
// as the QueueImpl cancellations in AMQP will return the buffer back to the queue, in the order they were consumed.
// There is an exception to that case, when there are more messages on the queue.
// This would be an optimization for our usage.
// avoiding scanning the entire List just to add at the end, so we compare the end first.
if (comparator.compare(tail.val(), e) >= 0) {
addTail(e);
return;
}
Node<E> fetching = head.next;
while (fetching.next != null) {
int compareNext = comparator.compare(fetching.next.val(), e);
if (compareNext <= 0) {
addAfter(fetching, e);
return;
}
fetching = fetching.next;
}
// this shouldn't happen as the tail was compared before iterating
// the only possibilities for this to happen are:
// - there is a bug on the comparator
// - This method is buggy
// - The list wasn't properly synchronized as this list does't support concurrent access
//
// Also I'm not bothering about creating a Logger ID for this, because the only reason for this code to exist
// is because my OCD level is not letting this out.
throw new IllegalStateException("Cannot find a suitable place for your element, There's a mismatch in the comparator or there was concurrent adccess on the queue");
}
}
private void addAfter(Node<E> node, E e) {
Node<E> newNode = Node.with(e);
Node<E> nextNode = node.next;
node.next = newNode;
newNode.prev = node;
newNode.next = nextNode;
nextNode.prev = newNode;
size++;
}
@Override
public E poll() {
Node<E> ret = head.next;

View File

@ -27,6 +27,8 @@ public interface PriorityLinkedList<T> {
void addTail(T t, int priority);
void addSorted(T t, int priority);
T poll();
void clear();

View File

@ -17,6 +17,7 @@
package org.apache.activemq.artemis.utils.collections;
import java.lang.reflect.Array;
import java.util.Comparator;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
@ -40,10 +41,15 @@ public class PriorityLinkedListImpl<T> implements PriorityLinkedList<T> {
private int lastPriority = -1;
public PriorityLinkedListImpl(final int priorities) {
this(priorities, null);
}
public PriorityLinkedListImpl(final int priorities, Comparator<T> comparator) {
levels = (LinkedListImpl<T>[]) Array.newInstance(LinkedListImpl.class, priorities);
for (int i = 0; i < priorities; i++) {
levels[i] = new LinkedListImpl<>();
levels[i] = new LinkedListImpl<>(comparator);
}
}
@ -80,6 +86,15 @@ public class PriorityLinkedListImpl<T> implements PriorityLinkedList<T> {
exclusiveIncrementSize(1);
}
@Override
public void addSorted(T t, int priority) {
checkHighest(priority);
levels[priority].addSorted(t);
exclusiveIncrementSize(1);
}
@Override
public T poll() {
T t = null;

View File

@ -405,7 +405,7 @@ public class AMQPSessionCallback implements SessionCallback {
public void cancel(Object brokerConsumer, Message message, boolean updateCounts) throws Exception {
OperationContext oldContext = recoverContext();
try {
((ServerConsumer) brokerConsumer).individualCancel(message.getMessageID(), updateCounts);
((ServerConsumer) brokerConsumer).individualCancel(message.getMessageID(), updateCounts, true);
((ServerConsumer) brokerConsumer).getQueue().forceDelivery();
} finally {
resetContext(oldContext);

View File

@ -131,7 +131,7 @@ public class MQTTPublishManager {
sendServerMessage(mqttid, message, deliveryCount, qos);
} else {
// Client must have disconnected and it's Subscription QoS cleared
consumer.individualCancel(message.getMessageID(), false);
consumer.individualCancel(message.getMessageID(), false, true);
}
}
}

View File

@ -172,7 +172,8 @@ public interface Queue extends Bindable,CriticalComponent {
void cancel(Transaction tx, MessageReference ref, boolean ignoreRedeliveryCheck);
void cancel(MessageReference reference, long timeBase) throws Exception;
/** @param sorted it should use the messageID as a reference to where to add it in the queue */
void cancel(MessageReference reference, long timeBase, boolean sorted) throws Exception;
void deliverAsync();

View File

@ -98,7 +98,7 @@ public interface ServerConsumer extends Consumer, ConsumerInfo {
void reject(long messageID) throws Exception;
void individualCancel(long messageID, boolean failed) throws Exception;
void individualCancel(long messageID, boolean failed, boolean sorted) throws Exception;
void forceDelivery(long sequence);

View File

@ -354,7 +354,7 @@ public class BridgeImpl implements Bridge, SessionFailureListener, SendAcknowled
refqueue = ref.getQueue();
try {
refqueue.cancel(ref, timeBase);
refqueue.cancel(ref, timeBase, false);
} catch (Exception e) {
// There isn't much we can do besides log an error
ActiveMQServerLogger.LOGGER.errorCancellingRefOnBridge(e, ref);

View File

@ -16,6 +16,7 @@
*/
package org.apache.activemq.artemis.core.server.impl;
import java.util.Comparator;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.function.Consumer;
@ -33,6 +34,28 @@ import org.apache.activemq.artemis.utils.collections.LinkedListImpl;
*/
public class MessageReferenceImpl extends LinkedListImpl.Node<MessageReferenceImpl> implements MessageReference, Runnable {
private static final MessageReferenceComparatorByID idComparator = new MessageReferenceComparatorByID();
public static Comparator<MessageReference> getIDComparator() {
return idComparator;
}
private static class MessageReferenceComparatorByID implements Comparator<MessageReference> {
@Override
public int compare(MessageReference o1, MessageReference o2) {
long value = o2.getMessage().getMessageID() - o1.getMessage().getMessageID();
if (value > 0) {
return 1;
} else if (value < 0) {
return -1;
} else {
return 0;
}
}
}
private static final AtomicIntegerFieldUpdater<MessageReferenceImpl> DELIVERY_COUNT_UPDATER = AtomicIntegerFieldUpdater
.newUpdater(MessageReferenceImpl.class, "deliveryCount");

View File

@ -179,7 +179,7 @@ public class QueueImpl extends CriticalComponentImpl implements Queue {
private final MpscUnboundedArrayQueue<MessageReference> intermediateMessageReferences = new MpscUnboundedArrayQueue<>(8192);
// This is where messages are stored
private final PriorityLinkedList<MessageReference> messageReferences = new PriorityLinkedListImpl<>(QueueImpl.NUM_PRIORITIES);
private final PriorityLinkedList<MessageReference> messageReferences = new PriorityLinkedListImpl<>(QueueImpl.NUM_PRIORITIES, MessageReferenceImpl.getIDComparator());
// The quantity of pagedReferences on messageReferences priority list
private final AtomicInteger pagedReferences = new AtomicInteger(0);
@ -1631,11 +1631,15 @@ public class QueueImpl extends CriticalComponentImpl implements Queue {
}
@Override
public synchronized void cancel(final MessageReference reference, final long timeBase) throws Exception {
public synchronized void cancel(final MessageReference reference, final long timeBase, boolean sorted) throws Exception {
Pair<Boolean, Boolean> redeliveryResult = checkRedelivery(reference, timeBase, false);
if (redeliveryResult.getA()) {
if (!scheduledDeliveryHandler.checkAndSchedule(reference, false)) {
internalAddHead(reference);
if (sorted) {
internalAddSorted(reference);
} else {
internalAddHead(reference);
}
}
resetAllIterators();
@ -2469,6 +2473,23 @@ public class QueueImpl extends CriticalComponentImpl implements Queue {
messageReferences.addHead(ref, priority);
}
/**
* The caller of this method requires synchronized on the queue.
* I'm not going to add synchronized to this method just for a precaution,
* as I'm not 100% sure this won't cause any extra runtime.
*
* @param ref
*/
private void internalAddSorted(final MessageReference ref) {
queueMemorySize.addAndGet(ref.getMessageMemoryEstimate());
pendingMetrics.incrementMetrics(ref);
refAdded(ref);
int priority = getPriority(ref);
messageReferences.addSorted(ref, priority);
}
private int getPriority(MessageReference ref) {
try {
return ref.getMessage().getPriority();

View File

@ -992,7 +992,7 @@ public class ServerConsumerImpl implements ServerConsumer, ReadyListener {
}
@Override
public synchronized void individualCancel(final long messageID, boolean failed) throws Exception {
public synchronized void individualCancel(final long messageID, boolean failed, boolean sorted) throws Exception {
if (browseOnly) {
return;
}
@ -1007,7 +1007,7 @@ public class ServerConsumerImpl implements ServerConsumer, ReadyListener {
ref.decrementDeliveryCount();
}
ref.getQueue().cancel(ref, System.currentTimeMillis());
ref.getQueue().cancel(ref, System.currentTimeMillis(), sorted);
}

View File

@ -1146,7 +1146,7 @@ public class ServerSessionImpl implements ServerSession, FailureListener {
ServerConsumer consumer = locateConsumer(consumerID);
if (consumer != null) {
consumer.individualCancel(messageID, failed);
consumer.individualCancel(messageID, failed, false);
}
}

View File

@ -1108,7 +1108,7 @@ public class ScheduledDeliveryHandlerTest extends Assert {
}
@Override
public void cancel(MessageReference reference, long timeBase) throws Exception {
public void cancel(MessageReference reference, long timeBase, boolean backInPlace) throws Exception {
}

View File

@ -150,7 +150,7 @@ public class DummyServerConsumer implements ServerConsumer {
}
@Override
public void individualCancel(long messageID, boolean failed) throws Exception {
public void individualCancel(long messageID, boolean failed, boolean sorted) throws Exception {
}

View File

@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.tests.integration.client;
import javax.jms.Connection;
import javax.jms.ConnectionFactory;
import javax.jms.Destination;
import javax.jms.Message;
import javax.jms.MessageConsumer;
import javax.jms.MessageProducer;
import javax.jms.Queue;
import javax.jms.Session;
import javax.jms.TextMessage;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.activemq.artemis.tests.util.JMSTestBase;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import static org.apache.activemq.artemis.tests.util.CFUtil.createConnectionFactory;
@RunWith(value = Parameterized.class)
public class JMSOrderTest extends JMSTestBase {
String protocol;
ConnectionFactory protocolCF;
public JMSOrderTest(String protocol) {
this.protocol = protocol;
}
@Before
public void setupCF() {
protocolCF = createConnectionFactory(protocol, "tcp://localhost:61616");
}
@Parameterized.Parameters(name = "protocol={0}")
public static Collection getParameters() {
return Arrays.asList(new Object[][]{{"AMQP"}, {"OPENWIRE"}, {"CORE"}});
}
protected void sendToAmqQueue(int count) throws Exception {
Connection activemqConnection = protocolCF.createConnection();
Session amqSession = activemqConnection.createSession(false, Session.AUTO_ACKNOWLEDGE);
Queue amqTestQueue = amqSession.createQueue(name.getMethodName());
sendMessages(activemqConnection, amqTestQueue, count);
activemqConnection.close();
}
public void sendMessages(Connection connection, Destination destination, int count) throws Exception {
Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
MessageProducer p = session.createProducer(destination);
for (int i = 1; i <= count; i++) {
TextMessage message = session.createTextMessage();
message.setText("TextMessage: " + i);
message.setIntProperty("nr", i);
p.send(message);
}
session.close();
}
@Test(timeout = 60000)
public void testReceiveSomeThenRollback() throws Exception {
Connection connection = protocolCF.createConnection();
try {
connection.start();
int totalCount = 5;
int consumeBeforeRollback = 2;
sendToAmqQueue(totalCount);
Session session = connection.createSession(true, Session.SESSION_TRANSACTED);
Queue queue = session.createQueue(name.getMethodName());
MessageConsumer consumer = session.createConsumer(queue);
for (int i = 1; i <= consumeBeforeRollback; i++) {
Message message = consumer.receive(3000);
assertNotNull(message);
assertEquals("Unexpected message number", i, message.getIntProperty("nr"));
}
session.rollback();
// Consume again.. the previously consumed messages should get delivered
// again after the rollback and then the remainder should follow
List<Integer> messageNumbers = new ArrayList<>();
for (int i = 1; i <= totalCount; i++) {
Message message = consumer.receive(3000);
assertNotNull("Failed to receive message: " + i, message);
int msgNum = message.getIntProperty("nr");
System.out.println("Received " + msgNum);
messageNumbers.add(msgNum);
}
session.commit();
assertEquals("Unexpected size of list", totalCount, messageNumbers.size());
for (int i = 0; i < messageNumbers.size(); i++) {
assertEquals("Unexpected order of messages: " + messageNumbers, Integer.valueOf(i + 1), messageNumbers.get(i));
}
} finally {
connection.close();
}
}
}

View File

@ -351,7 +351,7 @@ public class FakeQueue extends CriticalComponentImpl implements Queue {
}
@Override
public void cancel(final MessageReference reference, final long timeBase) throws Exception {
public void cancel(final MessageReference reference, final long timeBase, boolean sorted) throws Exception {
// no-op
}

View File

@ -17,14 +17,18 @@
package org.apache.activemq.artemis.tests.unit.util;
import java.lang.ref.WeakReference;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.activemq.artemis.tests.util.ActiveMQTestBase;
import org.apache.activemq.artemis.tests.util.RandomUtil;
import org.apache.activemq.artemis.utils.collections.LinkedListImpl;
import org.apache.activemq.artemis.utils.collections.LinkedListIterator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@ -37,7 +41,74 @@ public class LinkedListTest extends ActiveMQTestBase {
public void setUp() throws Exception {
super.setUp();
list = new LinkedListImpl<>();
list = new LinkedListImpl<>(integerComparator);
}
Comparator<Integer> integerComparator = new Comparator<Integer>() {
@Override
public int compare(Integer o1, Integer o2) {
if (o1.intValue() == o2.intValue()) {
return 0;
}
if (o2.intValue() > o1.intValue()) {
return 1;
} else {
return -1;
}
}
};
@Test
public void addSorted() {
list.addSorted(1);
list.addSorted(3);
list.addSorted(2);
list.addSorted(0);
validateOrder(null);
Assert.assertEquals(4, list.size());
}
@Test
public void randomSorted() {
HashSet<Integer> values = new HashSet<>();
for (int i = 0; i < 1000; i++) {
int value = RandomUtil.randomInt();
if (!values.contains(value)) {
values.add(value);
list.addSorted(value);
}
}
Assert.assertEquals(values.size(), list.size());
validateOrder(values);
Assert.assertEquals(0, values.size());
}
private void validateOrder(HashSet<Integer> values) {
Integer previous = null;
LinkedListIterator<Integer> integerIterator = list.iterator();
while (integerIterator.hasNext()) {
Integer value = integerIterator.next();
if (previous != null) {
Assert.assertTrue(value + " should be > " + previous, integerComparator.compare(previous, value) > 0);
Assert.assertTrue(value + " should be > " + previous, value.intValue() > previous.intValue());
}
if (values != null) {
values.remove(value);
}
previous = value;
}
integerIterator.close();
}
@Test