Improve configuration for disabling non-SASL connections.
(cherry picked from commit c49db029ab)
This commit is contained in:
Timothy Bish 2016-06-09 17:32:41 -04:00
parent d594248db5
commit f471b51c2a
3 changed files with 71 additions and 2 deletions

View File

@ -40,6 +40,7 @@ public class AmqpWireFormat implements WireFormat {
public static final int DEFAULT_CONNECTION_TIMEOUT = 30000; public static final int DEFAULT_CONNECTION_TIMEOUT = 30000;
public static final int DEFAULT_IDLE_TIMEOUT = 30000; public static final int DEFAULT_IDLE_TIMEOUT = 30000;
public static final int DEFAULT_PRODUCER_CREDIT = 1000; public static final int DEFAULT_PRODUCER_CREDIT = 1000;
public static final boolean DEFAULT_ALLOW_NON_SASL_CONNECTIONS = true;
private static final int SASL_PROTOCOL = 3; private static final int SASL_PROTOCOL = 3;
@ -50,6 +51,7 @@ public class AmqpWireFormat implements WireFormat {
private int idelTimeout = DEFAULT_IDLE_TIMEOUT; private int idelTimeout = DEFAULT_IDLE_TIMEOUT;
private int producerCredit = DEFAULT_PRODUCER_CREDIT; private int producerCredit = DEFAULT_PRODUCER_CREDIT;
private String transformer = InboundTransformer.TRANSFORMER_JMS; private String transformer = InboundTransformer.TRANSFORMER_JMS;
private boolean allowNonSaslConnections = DEFAULT_ALLOW_NON_SASL_CONNECTIONS;
private boolean magicRead = false; private boolean magicRead = false;
private ResetListener resetListener; private ResetListener resetListener;
@ -58,8 +60,6 @@ public class AmqpWireFormat implements WireFormat {
void onProtocolReset(); void onProtocolReset();
} }
private boolean allowNonSaslConnections = true;
@Override @Override
public ByteSequence marshal(Object command) throws IOException { public ByteSequence marshal(Object command) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
@ -131,6 +131,10 @@ public class AmqpWireFormat implements WireFormat {
return false; return false;
} }
if (!(header.getProtocolId() == 0 || header.getProtocolId() == 3)) {
return false;
}
if (!isAllowNonSaslConnections() && header.getProtocolId() != SASL_PROTOCOL) { if (!isAllowNonSaslConnections() && header.getProtocolId() != SASL_PROTOCOL) {
return false; return false;
} }

View File

@ -30,6 +30,7 @@ public class AmqpWireFormatFactory implements WireFormatFactory {
private int idelTimeout = AmqpWireFormat.DEFAULT_IDLE_TIMEOUT; private int idelTimeout = AmqpWireFormat.DEFAULT_IDLE_TIMEOUT;
private int producerCredit = AmqpWireFormat.DEFAULT_PRODUCER_CREDIT; private int producerCredit = AmqpWireFormat.DEFAULT_PRODUCER_CREDIT;
private String transformer = InboundTransformer.TRANSFORMER_NATIVE; private String transformer = InboundTransformer.TRANSFORMER_NATIVE;
private boolean allowNonSaslConnections = AmqpWireFormat.DEFAULT_ALLOW_NON_SASL_CONNECTIONS;
@Override @Override
public WireFormat createWireFormat() { public WireFormat createWireFormat() {
@ -40,6 +41,7 @@ public class AmqpWireFormatFactory implements WireFormatFactory {
wireFormat.setIdleTimeout(getIdelTimeout()); wireFormat.setIdleTimeout(getIdelTimeout());
wireFormat.setProducerCredit(getProducerCredit()); wireFormat.setProducerCredit(getProducerCredit());
wireFormat.setTransformer(getTransformer()); wireFormat.setTransformer(getTransformer());
wireFormat.setAllowNonSaslConnections(isAllowNonSaslConnections());
return wireFormat; return wireFormat;
} }
@ -83,4 +85,12 @@ public class AmqpWireFormatFactory implements WireFormatFactory {
public void setTransformer(String transformer) { public void setTransformer(String transformer) {
this.transformer = transformer; this.transformer = transformer;
} }
public boolean isAllowNonSaslConnections() {
return allowNonSaslConnections;
}
public void setAllowNonSaslConnections(boolean allowNonSaslConnections) {
this.allowNonSaslConnections = allowNonSaslConnections;
}
} }

View File

@ -60,11 +60,17 @@ public class UnsupportedClientTest extends AmqpTestSupport {
super.setUp(); super.setUp();
} }
@Override
public String getAdditionalConfig() {
return "&wireFormat.allowNonSaslConnections=false";
}
@Test(timeout = 60000) @Test(timeout = 60000)
public void testOlderProtocolIsRejected() throws Exception { public void testOlderProtocolIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader(); AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(0); header.setMajor(0);
header.setMinor(9); header.setMinor(9);
header.setRevision(1); header.setRevision(1);
@ -87,6 +93,7 @@ public class UnsupportedClientTest extends AmqpTestSupport {
AmqpHeader header = new AmqpHeader(); AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(2); header.setMajor(2);
header.setMinor(0); header.setMinor(0);
header.setRevision(0); header.setRevision(0);
@ -109,6 +116,7 @@ public class UnsupportedClientTest extends AmqpTestSupport {
AmqpHeader header = new AmqpHeader(); AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(1); header.setMajor(1);
header.setMinor(1); header.setMinor(1);
header.setRevision(0); header.setRevision(0);
@ -131,6 +139,7 @@ public class UnsupportedClientTest extends AmqpTestSupport {
AmqpHeader header = new AmqpHeader(); AmqpHeader header = new AmqpHeader();
header.setProtocolId(3);
header.setMajor(1); header.setMajor(1);
header.setMinor(0); header.setMinor(0);
header.setRevision(1); header.setRevision(1);
@ -148,6 +157,52 @@ public class UnsupportedClientTest extends AmqpTestSupport {
doTestInvalidHeaderProcessing(amqpNioPlusSslPort, header, true); doTestInvalidHeaderProcessing(amqpNioPlusSslPort, header, true);
} }
@Test(timeout = 60000)
public void testNonSaslClientIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(0);
header.setMajor(1);
header.setMinor(0);
header.setRevision(0);
// Test TCP
doTestInvalidHeaderProcessing(amqpPort, header, false);
// Test SSL
doTestInvalidHeaderProcessing(amqpSslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(amqpNioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(amqpNioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testUnkownProtocolIdIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setProtocolId(5);
header.setMajor(1);
header.setMinor(0);
header.setRevision(0);
// Test TCP
doTestInvalidHeaderProcessing(amqpPort, header, false);
// Test SSL
doTestInvalidHeaderProcessing(amqpSslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(amqpNioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(amqpNioPlusSslPort, header, true);
}
@Test(timeout = 60000) @Test(timeout = 60000)
public void testInvalidProtocolHeader() throws Exception { public void testInvalidProtocolHeader() throws Exception {