AMQ-9658 - Properly increment transport receive counter (#1389)

Switch to using an AtomicInteger for tracking bytes received in a
TcpTransport. This makes incrementing the counter an atomic operation.
Previously a volatile int was used and incrementing volatiles is not
atomic because it's a 3 step process of read, update, set.

This also makes a small fix to ensure that the full initialization
buffer will always be entirely read and processed when using
the auto+nio+ssl transport. Previous the code assumed only the first
command was stored in the initialization buffer but technically more
bytes could exist for a future command (even if unlikely with the
current Java implementation).
This commit is contained in:
Christopher L. Shannon 2025-02-05 10:02:28 -05:00 committed by GitHub
parent fb52acb844
commit d9e89f4b5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 41 additions and 39 deletions

View File

@ -140,7 +140,7 @@ public class AmqpNioTransport extends TcpTransport {
}
protected void processBuffer(ByteBuffer buffer, int readSize) throws Exception {
receiveCounter += readSize;
receiveCounter.addAndGet(readSize);
buffer.flip();
frameReader.parse(buffer);
@ -164,4 +164,4 @@ public class AmqpNioTransport extends TcpTransport {
super.doStop(stopper);
}
}
}
}

View File

@ -290,7 +290,7 @@ public class AutoTcpTransportServer extends TcpTransportServer {
try {
//If this fails and throws an exception and the socket will be closed
waitForProtocolDetectionFinish(future, readBytes);
waitForProtocolDetectionFinish(future, readBytes.get());
} finally {
//call cancel in case task didn't complete
future.cancel(true);
@ -311,7 +311,7 @@ public class AutoTcpTransportServer extends TcpTransportServer {
return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory);
}
protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception {
protected void waitForProtocolDetectionFinish(final Future<?> future, final int readBytes) throws Exception {
try {
//Wait for protocolDetectionTimeOut if defined
if (protocolDetectionTimeOut > 0) {
@ -321,7 +321,7 @@ public class AutoTcpTransportServer extends TcpTransportServer {
}
} catch (TimeoutException e) {
throw new InactivityIOException("Client timed out before wire format could be detected. " +
" 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent.");
" 8 bytes are required to detect the protocol but only: " + readBytes + " byte(s) were sent.");
}
}

View File

@ -145,20 +145,20 @@ public class AutoNIOSSLTransportServer extends AutoTcpTransportServer {
//to be told when bytes are ready
in.serviceRead();
attempts++;
} while(in.getReadSize().get() < 8 && !Thread.interrupted());
} while(in.getReceiveCounter() < 8 && !Thread.interrupted());
}
});
try {
//If this fails and throws an exception and the socket will be closed
waitForProtocolDetectionFinish(future, in.getReadSize());
waitForProtocolDetectionFinish(future, in.getReceiveCounter());
} finally {
//call cancel in case task didn't complete which will interrupt the task
future.cancel(true);
}
in.stop();
InitBuffer initBuffer = new InitBuffer(in.getReadSize().get(), ByteBuffer.allocate(in.getReadData().length));
InitBuffer initBuffer = new InitBuffer(in.getReceiveCounter(), ByteBuffer.allocate(in.getReadData().length));
initBuffer.buffer.put(in.getReadData());
ProtocolInfo protocolInfo = detectProtocol(in.getReadData());

View File

@ -148,16 +148,10 @@ public class AutoInitNioSSLTransport extends NIOSSLTransport {
private volatile byte[] readData;
private final AtomicInteger readSize = new AtomicInteger();
public byte[] getReadData() {
return readData != null ? readData : new byte[0];
}
public AtomicInteger getReadSize() {
return readSize;
}
@Override
public void serviceRead() {
try {
@ -187,14 +181,13 @@ public class AutoInitNioSSLTransport extends NIOSSLTransport {
break;
}
receiveCounter += readCount;
readSize.addAndGet(readCount);
receiveCounter.addAndGet(readCount);
}
if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
processCommand(plain);
//we have received enough bytes to detect the protocol
if (receiveCounter >= 8) {
if (receiveCounter.get() >= 8) {
break;
}
}
@ -208,7 +201,7 @@ public class AutoInitNioSSLTransport extends NIOSSLTransport {
@Override
protected void processCommand(ByteBuffer plain) throws Exception {
ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter);
ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter.get());
if (readData != null) {
newBuffer.put(readData);
}

View File

@ -215,19 +215,27 @@ public class NIOSSLTransport extends NIOTransport {
}
//Only used for the auto transport to abort the openwire init method early if already initialized
boolean openWireInititialized = false;
boolean openWireInitialized = false;
protected void doOpenWireInit() throws Exception {
//Do this later to let wire format negotiation happen
if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) {
if (initBuffer != null && !openWireInitialized && this.wireFormat instanceof OpenWireFormat) {
initBuffer.buffer.flip();
if (initBuffer.buffer.hasRemaining()) {
nextFrameSize = -1;
receiveCounter += initBuffer.readSize;
processCommand(initBuffer.buffer);
processCommand(initBuffer.buffer);
receiveCounter.addAndGet(initBuffer.readSize);
do {
// This should almost always just be called 2 times, the first call reads
// the size and allocates space for the frame. The second call reads
// in the frame to process. This is enough to read in the initial WireFormatInfo
// frame that will be sent. However, it's technically possible for
// there to be extra data after that if more bytes came in during the initial
// socket read if a client sends more, so keep calling until we process the
// entire initial buffer before we continue so we do not miss any bytes.
processCommand(initBuffer.buffer);
} while (initBuffer.buffer.hasRemaining());
initBuffer.buffer.clear();
openWireInititialized = true;
openWireInitialized = true;
}
}
}
@ -277,7 +285,7 @@ public class NIOSSLTransport extends NIOTransport {
break;
}
receiveCounter += readCount;
receiveCounter.addAndGet(readCount);
}
if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {

View File

@ -124,7 +124,7 @@ public class NIOTransport extends TcpTransport {
break;
}
this.receiveCounter += readSize;
this.receiveCounter.addAndGet(readSize);
if (currentBuffer.hasRemaining()) {
continue;
}

View File

@ -34,6 +34,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.SocketFactory;
@ -130,8 +131,8 @@ public class TcpTransport extends TransportThreadSupport implements Transport, S
protected boolean useLocalHost = false;
protected int minmumWireFormatVersion;
protected SocketFactory socketFactory;
protected final AtomicReference<CountDownLatch> stoppedLatch = new AtomicReference<CountDownLatch>();
protected volatile int receiveCounter;
protected final AtomicReference<CountDownLatch> stoppedLatch = new AtomicReference<>();
protected final AtomicInteger receiveCounter = new AtomicInteger();
protected Map<String, Object> socketOptions;
private int soLinger = Integer.MIN_VALUE;
@ -615,22 +616,22 @@ public class TcpTransport extends TransportThreadSupport implements Transport, S
TcpBufferedInputStream buffIn = new TcpBufferedInputStream(socket.getInputStream(), ioBufferSize) {
@Override
public int read() throws IOException {
receiveCounter++;
receiveCounter.incrementAndGet();
return super.read();
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
receiveCounter++;
receiveCounter.incrementAndGet();
return super.read(b, off, len);
}
@Override
public long skip(long n) throws IOException {
receiveCounter++;
receiveCounter.incrementAndGet();
return super.skip(n);
}
@Override
protected void fill() throws IOException {
receiveCounter++;
receiveCounter.incrementAndGet();
super.fill();
}
};
@ -684,7 +685,7 @@ public class TcpTransport extends TransportThreadSupport implements Transport, S
@Override
public int getReceiveCounter() {
return receiveCounter;
return receiveCounter.get();
}
public static class InitBuffer {

View File

@ -70,7 +70,7 @@ public class MQTTNIOSSLTransport extends NIOSSLTransport {
protected void doInit() throws Exception {
if (initBuffer != null) {
nextFrameSize = -1;
receiveCounter += initBuffer.readSize;
receiveCounter.addAndGet(initBuffer.readSize);
initBuffer.buffer.flip();
processCommand(initBuffer.buffer);
}
@ -78,4 +78,4 @@ public class MQTTNIOSSLTransport extends NIOSSLTransport {
}
}
}

View File

@ -131,7 +131,7 @@ public class MQTTNIOTransport extends TcpTransport {
DataByteArrayInputStream dis = new DataByteArrayInputStream(buffer.array());
codec.parse(dis, readSize);
receiveCounter += readSize;
receiveCounter.addAndGet(readSize);
// clear the buffer
buffer.clear();
@ -154,4 +154,4 @@ public class MQTTNIOTransport extends TcpTransport {
super.doStop(stopper);
}
}
}
}

View File

@ -99,7 +99,7 @@ public class StompNIOSSLTransport extends NIOSSLTransport {
protected void doInit() throws Exception {
if (initBuffer != null) {
nextFrameSize = -1;
receiveCounter += initBuffer.readSize;
receiveCounter.addAndGet(initBuffer.readSize);
initBuffer.buffer.flip();
processCommand(initBuffer.buffer);
}

View File

@ -128,7 +128,7 @@ public class StompNIOTransport extends TcpTransport {
}
protected void processBuffer(ByteBuffer buffer, int readSize) throws Exception {
receiveCounter += readSize;
receiveCounter.addAndGet(readSize);
buffer.flip();