From b62e9875fed987f341b946082f4e28866e54b979 Mon Sep 17 00:00:00 2001 From: gtully Date: Tue, 12 Dec 2017 14:04:06 +0000 Subject: [PATCH] [ARTEMIS-1552] ensure gssapi sasl mech can deal with empty receive buffer --- .../protocol/amqp/sasl/GSSAPIServerSASL.java | 5 +- .../integration/amqp/JMSSaslGssapiTest.java | 97 +++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/artemis-protocols/artemis-amqp-protocol/src/main/java/org/apache/activemq/artemis/protocol/amqp/sasl/GSSAPIServerSASL.java b/artemis-protocols/artemis-amqp-protocol/src/main/java/org/apache/activemq/artemis/protocol/amqp/sasl/GSSAPIServerSASL.java index e89d548959..c9b43fefa7 100644 --- a/artemis-protocols/artemis-amqp-protocol/src/main/java/org/apache/activemq/artemis/protocol/amqp/sasl/GSSAPIServerSASL.java +++ b/artemis-protocols/artemis-amqp-protocol/src/main/java/org/apache/activemq/artemis/protocol/amqp/sasl/GSSAPIServerSASL.java @@ -74,7 +74,10 @@ public class GSSAPIServerSASL implements ServerSASL { })); } - byte[] challenge = Subject.doAs(jaasId, (PrivilegedExceptionAction) () -> saslServer.evaluateResponse(bytes)); + byte[] challenge = null; + if (bytes.length > 0) { + challenge = Subject.doAs(jaasId, (PrivilegedExceptionAction) () -> saslServer.evaluateResponse(bytes)); + } if (saslServer.isComplete()) { result = new GSSAPISASLResult(true, new KerberosPrincipal(saslServer.getAuthorizationID())); } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/amqp/JMSSaslGssapiTest.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/amqp/JMSSaslGssapiTest.java index d66c83d728..5a93154325 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/amqp/JMSSaslGssapiTest.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/amqp/JMSSaslGssapiTest.java @@ -25,16 +25,34 @@ import javax.jms.TextMessage; import java.io.File; import java.net.URI; import java.net.URL; +import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnector; +import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants; import org.apache.activemq.artemis.core.security.Role; import org.apache.activemq.artemis.core.server.ActiveMQServer; +import org.apache.activemq.artemis.protocol.amqp.broker.ProtonProtocolManagerFactory; +import org.apache.activemq.artemis.protocol.amqp.client.AMQPClientConnectionFactory; +import org.apache.activemq.artemis.protocol.amqp.client.ProtonClientConnectionManager; +import org.apache.activemq.artemis.protocol.amqp.client.ProtonClientProtocolManager; +import org.apache.activemq.artemis.protocol.amqp.proton.handler.EventHandler; +import org.apache.activemq.artemis.protocol.amqp.proton.handler.ProtonHandler; +import org.apache.activemq.artemis.protocol.amqp.sasl.ClientSASL; +import org.apache.activemq.artemis.protocol.amqp.sasl.ClientSASLFactory; import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager; +import org.apache.activemq.artemis.tests.util.Wait; import org.apache.activemq.artemis.utils.RandomUtil; import org.apache.hadoop.minikdc.MiniKdc; import org.apache.qpid.jms.JmsConnectionFactory; +import org.apache.qpid.jms.sasl.GssapiMechanism; +import org.apache.qpid.proton.amqp.Symbol; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -164,4 +182,83 @@ public class JMSSaslGssapiTest extends JMSClientTestSupport { assertTrue(expected.getMessage().contains("SASL")); } } + + @Test + public void testOutboundWithSlowMech() throws Exception { + final Map config = new LinkedHashMap<>(); config.put(TransportConstants.HOST_PROP_NAME, "localhost"); + config.put(TransportConstants.PORT_PROP_NAME, String.valueOf(AMQP_PORT)); + final ClientSASLFactory clientSASLFactory = new ClientSASLFactory() { + @Override + public ClientSASL chooseMechanism(String[] availableMechanims) { + GssapiMechanism gssapiMechanism = new GssapiMechanism(); + return new ClientSASL() { + @Override + public String getName() { + return gssapiMechanism.getName(); + } + + @Override + public byte[] getInitialResponse() { + gssapiMechanism.setUsername("client"); + gssapiMechanism.setServerName("localhost"); + try { + return gssapiMechanism.getInitialResponse(); + } catch (Exception e) { + e.printStackTrace(); + } + return new byte[0]; + } + + @Override + public byte[] getResponse(byte[] challenge) { + try { + // simulate a slow client + TimeUnit.SECONDS.sleep(4); + } catch (InterruptedException e) { + e.printStackTrace(); + } + try { + return gssapiMechanism.getChallengeResponse(challenge); + } catch (Exception e) { + e.printStackTrace(); + } + return new byte[0]; + } + }; + } + }; + + final AtomicBoolean connectionOpened = new AtomicBoolean(); + final AtomicBoolean authFailed = new AtomicBoolean(); + + EventHandler eventHandler = new EventHandler() { + @Override + public void onRemoteOpen(org.apache.qpid.proton.engine.Connection connection) throws Exception { + connectionOpened.set(true); + } + + @Override + public void onAuthFailed(ProtonHandler protonHandler, org.apache.qpid.proton.engine.Connection connection) { + authFailed.set(true); + } + }; + + ProtonClientConnectionManager lifeCycleListener = new ProtonClientConnectionManager(new AMQPClientConnectionFactory(server, "myid", Collections.singletonMap(Symbol.getSymbol("myprop"), "propvalue"), 5000), Optional.of(eventHandler), clientSASLFactory); + ProtonClientProtocolManager protocolManager = new ProtonClientProtocolManager(new ProtonProtocolManagerFactory(), server); + NettyConnector connector = new NettyConnector(config, lifeCycleListener, lifeCycleListener, server.getExecutorFactory().getExecutor(), server.getExecutorFactory().getExecutor(), server.getScheduledPool(), protocolManager); + connector.start(); + connector.createConnection(); + + try { + Wait.assertEquals(1, server::getConnectionCount); + Wait.assertTrue(connectionOpened::get); + Wait.assertFalse(authFailed::get); + + lifeCycleListener.stop(); + + Wait.assertEquals(0, server::getConnectionCount); + } finally { + lifeCycleListener.stop(); + } + } }