Improve configuration for disabling non-SASL connections.
(cherry picked from commit c49db029ab)
This commit is contained in:
Timothy Bish 2016-06-09 17:32:41 -04:00
parent d594248db5
commit f471b51c2a
3 changed files with 71 additions and 2 deletions

View File

@ -40,6 +40,7 @@ public class AmqpWireFormat implements WireFormat {
public static final int DEFAULT_CONNECTION_TIMEOUT = 30000;
public static final int DEFAULT_IDLE_TIMEOUT = 30000;
public static final int DEFAULT_PRODUCER_CREDIT = 1000;
public static final boolean DEFAULT_ALLOW_NON_SASL_CONNECTIONS = true;
private static final int SASL_PROTOCOL = 3;
@ -50,6 +51,7 @@ public class AmqpWireFormat implements WireFormat {
private int idelTimeout = DEFAULT_IDLE_TIMEOUT;
private int producerCredit = DEFAULT_PRODUCER_CREDIT;
private String transformer = InboundTransformer.TRANSFORMER_JMS;
private boolean allowNonSaslConnections = DEFAULT_ALLOW_NON_SASL_CONNECTIONS;
private boolean magicRead = false;
private ResetListener resetListener;
@ -58,8 +60,6 @@ public class AmqpWireFormat implements WireFormat {
void onProtocolReset();
}
private boolean allowNonSaslConnections = true;
@Override
public ByteSequence marshal(Object command) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
@ -131,6 +131,10 @@ public class AmqpWireFormat implements WireFormat {
return false;
}
if (!(header.getProtocolId() == 0 || header.getProtocolId() == 3)) {
return false;
}
if (!isAllowNonSaslConnections() && header.getProtocolId() != SASL_PROTOCOL) {
return false;
}

View File

@ -30,6 +30,7 @@ public class AmqpWireFormatFactory implements WireFormatFactory {
private int idelTimeout = AmqpWireFormat.DEFAULT_IDLE_TIMEOUT;
private int producerCredit = AmqpWireFormat.DEFAULT_PRODUCER_CREDIT;
private String transformer = InboundTransformer.TRANSFORMER_NATIVE;
private boolean allowNonSaslConnections = AmqpWireFormat.DEFAULT_ALLOW_NON_SASL_CONNECTIONS;
@Override
public WireFormat createWireFormat() {
@ -40,6 +41,7 @@ public class AmqpWireFormatFactory implements WireFormatFactory {
wireFormat.setIdleTimeout(getIdelTimeout());
wireFormat.setProducerCredit(getProducerCredit());
wireFormat.setTransformer(getTransformer());
wireFormat.setAllowNonSaslConnections(isAllowNonSaslConnections());
return wireFormat;
}
@ -83,4 +85,12 @@ public class AmqpWireFormatFactory implements WireFormatFactory {
public void setTransformer(String transformer) {
this.transformer = transformer;
}
public boolean isAllowNonSaslConnections() {
return allowNonSaslConnections;
}
public void setAllowNonSaslConnections(boolean allowNonSaslConnections) {
this.allowNonSaslConnections = allowNonSaslConnections;
}
}

View File

@ -60,11 +60,17 @@ public class UnsupportedClientTest extends AmqpTestSupport {
super.setUp();
}
@Override
public String getAdditionalConfig() {
return "&wireFormat.allowNonSaslConnections=false";
}
@Test(timeout = 60000)
public void testOlderProtocolIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(0);
header.setMinor(9);
header.setRevision(1);
@ -87,6 +93,7 @@ public class UnsupportedClientTest extends AmqpTestSupport {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(2);
header.setMinor(0);
header.setRevision(0);
@ -109,6 +116,7 @@ public class UnsupportedClientTest extends AmqpTestSupport {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(1);
header.setMinor(1);
header.setRevision(0);
@ -131,6 +139,7 @@ public class UnsupportedClientTest extends AmqpTestSupport {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(1);
header.setMinor(0);
header.setRevision(1);
@ -148,6 +157,52 @@ public class UnsupportedClientTest extends AmqpTestSupport {
doTestInvalidHeaderProcessing(amqpNioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testNonSaslClientIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(0);
header.setMajor(1);
header.setMinor(0);
header.setRevision(0);
// Test TCP
doTestInvalidHeaderProcessing(amqpPort, header, false);
// Test SSL
doTestInvalidHeaderProcessing(amqpSslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(amqpNioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(amqpNioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testUnkownProtocolIdIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(5);
header.setMajor(1);
header.setMinor(0);
header.setRevision(0);
// Test TCP
doTestInvalidHeaderProcessing(amqpPort, header, false);
// Test SSL
doTestInvalidHeaderProcessing(amqpSslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(amqpNioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(amqpNioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testInvalidProtocolHeader() throws Exception {