diff --git a/core/src/main/java/org/springframework/security/authentication/concurrent/ConcurrentSessionControllerImpl.java b/core/src/main/java/org/springframework/security/authentication/concurrent/ConcurrentSessionControllerImpl.java index 10df5b2ca2..e4de3a14c5 100644 --- a/core/src/main/java/org/springframework/security/authentication/concurrent/ConcurrentSessionControllerImpl.java +++ b/core/src/main/java/org/springframework/security/authentication/concurrent/ConcurrentSessionControllerImpl.java @@ -15,6 +15,8 @@ package org.springframework.security.authentication.concurrent; +import java.util.List; + import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.SpringSecurityMessageSource; @@ -29,14 +31,13 @@ import org.springframework.util.Assert; /** - * Base implementation of {@link ConcurrentSessionControllerImpl} which prohibits simultaneous logins.

By default - * uses {@link SessionRegistryImpl}, although any SessionRegistry may be used.

+ * Base implementation of {@link ConcurrentSessionControllerImpl} which prohibits simultaneous logins. * * @author Ben Alex * @version $Id$ */ public class ConcurrentSessionControllerImpl implements ConcurrentSessionController, InitializingBean, - MessageSourceAware { + MessageSourceAware { //~ Instance fields ================================================================================================ protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); @@ -61,10 +62,10 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl * @param allowableSessions DOCUMENT ME! * @param registry an instance of the SessionRegistry for subclass use * - * @throws ConcurrentLoginException DOCUMENT ME! + * @throws ConcurrentLoginException if the */ - protected void allowableSessionsExceeded(String sessionId, SessionInformation[] sessions, int allowableSessions, - SessionRegistry registry) { + protected void allowableSessionsExceeded(String sessionId, List sessions, int allowableSessions, + SessionRegistry registry) { if (exceptionIfMaximumExceeded || (sessions == null)) { throw new ConcurrentLoginException(messages.getMessage("ConcurrentSessionControllerImpl.exceededAllowed", new Object[] {new Integer(allowableSessions)}, @@ -74,30 +75,25 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl // Determine least recently used session, and mark it for invalidation SessionInformation leastRecentlyUsed = null; - for (int i = 0; i < sessions.length; i++) { + for (int i = 0; i < sessions.size(); i++) { if ((leastRecentlyUsed == null) - || sessions[i].getLastRequest().before(leastRecentlyUsed.getLastRequest())) { - leastRecentlyUsed = sessions[i]; + || sessions.get(i).getLastRequest().before(leastRecentlyUsed.getLastRequest())) { + leastRecentlyUsed = sessions.get(i); } } leastRecentlyUsed.expireNow(); } - public void checkAuthenticationAllowed(Authentication request) - throws AuthenticationException { + public void checkAuthenticationAllowed(Authentication request) throws AuthenticationException { Assert.notNull(request, "Authentication request cannot be null (violation of interface contract)"); Object principal = SessionRegistryUtils.obtainPrincipalFromAuthentication(request); String sessionId = SessionRegistryUtils.obtainSessionIdFromAuthentication(request); - SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false); + final List sessions = sessionRegistry.getAllSessions(principal, false); - int sessionCount = 0; - - if (sessions != null) { - sessionCount = sessions.length; - } + int sessionCount = sessions == null ? 0 : sessions.size(); int allowableSessions = getMaximumSessionsForThisUser(request); Assert.isTrue(allowableSessions != 0, "getMaximumSessionsForThisUser() must return either -1 to allow " @@ -106,13 +102,17 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl if (sessionCount < allowableSessions) { // They haven't got too many login sessions running at present return; - } else if (allowableSessions == -1) { + } + + if (allowableSessions == -1) { // We permit unlimited logins return; - } else if (sessionCount == allowableSessions) { + } + + if (sessionCount == allowableSessions) { // Only permit it though if this request is associated with one of the sessions - for (int i = 0; i < sessionCount; i++) { - if (sessions[i].getSessionId().equals(sessionId)) { + for (SessionInformation si : sessions) { + if (si.getSessionId().equals(sessionId)) { return; } } diff --git a/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistry.java b/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistry.java index 69b2ba66e3..347f3dd963 100644 --- a/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistry.java +++ b/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistry.java @@ -15,6 +15,8 @@ package org.springframework.security.authentication.concurrent; +import java.util.List; + /** * Maintains a registry of SessionInformation instances. * @@ -29,7 +31,7 @@ public interface SessionRegistry { * * @return each of the unique principals, which can then be presented to {@link #getAllSessions(Object, boolean)}. */ - Object[] getAllPrincipals(); + List getAllPrincipals(); /** * Obtains all the known sessions for the specified principal. Sessions that have been destroyed are not @@ -41,7 +43,7 @@ public interface SessionRegistry { * * @return the matching sessions for this principal, or null if none were found */ - SessionInformation[] getAllSessions(Object principal, boolean includeExpiredSessions); + List getAllSessions(Object principal, boolean includeExpiredSessions); /** * Obtains the session information for the specified sessionId. Even expired sessions are diff --git a/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistryImpl.java b/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistryImpl.java index be9bdbf75b..b6cc1df8c6 100644 --- a/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistryImpl.java +++ b/core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistryImpl.java @@ -16,6 +16,7 @@ package org.springframework.security.authentication.concurrent; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.HashMap; @@ -57,18 +58,18 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener // ~ Methods ======================================================================================================= - public Object[] getAllPrincipals() { - return principals.keySet().toArray(); + public List getAllPrincipals() { + return Arrays.asList(principals.keySet().toArray()); } - public SessionInformation[] getAllSessions(Object principal, boolean includeExpiredSessions) { - Set sessionsUsedByPrincipal = principals.get(principal); + public List getAllSessions(Object principal, boolean includeExpiredSessions) { + final Set sessionsUsedByPrincipal = principals.get(principal); if (sessionsUsedByPrincipal == null) { return null; } - List list = new ArrayList(); + List list = new ArrayList(sessionsUsedByPrincipal.size()); synchronized (sessionsUsedByPrincipal) { for (String sessionId : sessionsUsedByPrincipal) { @@ -84,7 +85,7 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener } } - return (SessionInformation[]) list.toArray(new SessionInformation[0]); + return list; } public SessionInformation getSessionInformation(String sessionId) { diff --git a/core/src/test/java/org/springframework/security/authentication/concurrent/SessionRegistryImplTests.java b/core/src/test/java/org/springframework/security/authentication/concurrent/SessionRegistryImplTests.java index c481887f8d..fc8e3632bc 100644 --- a/core/src/test/java/org/springframework/security/authentication/concurrent/SessionRegistryImplTests.java +++ b/core/src/test/java/org/springframework/security/authentication/concurrent/SessionRegistryImplTests.java @@ -18,6 +18,7 @@ package org.springframework.security.authentication.concurrent; import static org.junit.Assert.*; import java.util.Date; +import java.util.List; import org.junit.Before; import org.junit.Test; @@ -77,8 +78,9 @@ public class SessionRegistryImplTests { sessionRegistry.registerNewSession(sessionId2, principal1); sessionRegistry.registerNewSession(sessionId3, principal2); - assertEquals(principal1, sessionRegistry.getAllPrincipals()[0]); - assertEquals(principal2, sessionRegistry.getAllPrincipals()[1]); + assertEquals(2, sessionRegistry.getAllPrincipals().size()); + assertTrue(sessionRegistry.getAllPrincipals().contains(principal1)); + assertTrue(sessionRegistry.getAllPrincipals().contains(principal2)); } @Test @@ -95,7 +97,7 @@ public class SessionRegistryImplTests { assertNotNull(sessionRegistry.getSessionInformation(sessionId).getLastRequest()); // Retrieve existing session by principal - assertEquals(1, sessionRegistry.getAllSessions(principal, false).length); + assertEquals(1, sessionRegistry.getAllSessions(principal, false).size()); // Sleep to ensure SessionRegistryImpl will update time Thread.sleep(1000); @@ -107,7 +109,7 @@ public class SessionRegistryImplTests { assertTrue(retrieved.after(currentDateTime)); // Check it retrieves correctly when looked up via principal - assertEquals(retrieved, sessionRegistry.getAllSessions(principal, false)[0].getLastRequest()); + assertEquals(retrieved, sessionRegistry.getAllSessions(principal, false).get(0).getLastRequest()); // Clear session information sessionRegistry.removeSessionInformation(sessionId); @@ -124,13 +126,13 @@ public class SessionRegistryImplTests { String sessionId2 = "9876543210"; sessionRegistry.registerNewSession(sessionId1, principal); - SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false); - assertEquals(1, sessions.length); + List sessions = sessionRegistry.getAllSessions(principal, false); + assertEquals(1, sessions.size()); assertTrue(contains(sessionId1, principal)); sessionRegistry.registerNewSession(sessionId2, principal); sessions = sessionRegistry.getAllSessions(principal, false); - assertEquals(2, sessions.length); + assertEquals(2, sessions.size()); assertTrue(contains(sessionId2, principal)); // Expire one session @@ -149,18 +151,18 @@ public class SessionRegistryImplTests { String sessionId2 = "9876543210"; sessionRegistry.registerNewSession(sessionId1, principal); - SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false); - assertEquals(1, sessions.length); + List sessions = sessionRegistry.getAllSessions(principal, false); + assertEquals(1, sessions.size()); assertTrue(contains(sessionId1, principal)); sessionRegistry.registerNewSession(sessionId2, principal); sessions = sessionRegistry.getAllSessions(principal, false); - assertEquals(2, sessions.length); + assertEquals(2, sessions.size()); assertTrue(contains(sessionId2, principal)); sessionRegistry.removeSessionInformation(sessionId1); sessions = sessionRegistry.getAllSessions(principal, false); - assertEquals(1, sessions.length); + assertEquals(1, sessions.size()); assertTrue(contains(sessionId2, principal)); sessionRegistry.removeSessionInformation(sessionId2); @@ -169,10 +171,10 @@ public class SessionRegistryImplTests { } private boolean contains(String sessionId, Object principal) { - SessionInformation[] info = sessionRegistry.getAllSessions(principal, false); + List info = sessionRegistry.getAllSessions(principal, false); - for (int i = 0; i < info.length; i++) { - if (sessionId.equals(info[i].getSessionId())) { + for (int i = 0; i < info.size(); i++) { + if (sessionId.equals(info.get(i).getSessionId())) { return true; } }