SEC-1012: Refactor SessionRegistry interface to use Java 5 generics.

This commit is contained in:
Luke Taylor 2009-04-21 06:57:21 +00:00
parent b03e4f435b
commit ba6664f77f
4 changed files with 48 additions and 43 deletions

View File

@ -15,6 +15,8 @@
package org.springframework.security.authentication.concurrent; package org.springframework.security.authentication.concurrent;
import java.util.List;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.SpringSecurityMessageSource;
@ -29,8 +31,7 @@ import org.springframework.util.Assert;
/** /**
* Base implementation of {@link ConcurrentSessionControllerImpl} which prohibits simultaneous logins.<p>By default * Base implementation of {@link ConcurrentSessionControllerImpl} which prohibits simultaneous logins.
* uses {@link SessionRegistryImpl}, although any <code>SessionRegistry</code> may be used.</p>
* *
* @author Ben Alex * @author Ben Alex
* @version $Id$ * @version $Id$
@ -61,9 +62,9 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl
* @param allowableSessions DOCUMENT ME! * @param allowableSessions DOCUMENT ME!
* @param registry an instance of the <code>SessionRegistry</code> for subclass use * @param registry an instance of the <code>SessionRegistry</code> for subclass use
* *
* @throws ConcurrentLoginException DOCUMENT ME! * @throws ConcurrentLoginException if the
*/ */
protected void allowableSessionsExceeded(String sessionId, SessionInformation[] sessions, int allowableSessions, protected void allowableSessionsExceeded(String sessionId, List<SessionInformation> sessions, int allowableSessions,
SessionRegistry registry) { SessionRegistry registry) {
if (exceptionIfMaximumExceeded || (sessions == null)) { if (exceptionIfMaximumExceeded || (sessions == null)) {
throw new ConcurrentLoginException(messages.getMessage("ConcurrentSessionControllerImpl.exceededAllowed", throw new ConcurrentLoginException(messages.getMessage("ConcurrentSessionControllerImpl.exceededAllowed",
@ -74,30 +75,25 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl
// Determine least recently used session, and mark it for invalidation // Determine least recently used session, and mark it for invalidation
SessionInformation leastRecentlyUsed = null; SessionInformation leastRecentlyUsed = null;
for (int i = 0; i < sessions.length; i++) { for (int i = 0; i < sessions.size(); i++) {
if ((leastRecentlyUsed == null) if ((leastRecentlyUsed == null)
|| sessions[i].getLastRequest().before(leastRecentlyUsed.getLastRequest())) { || sessions.get(i).getLastRequest().before(leastRecentlyUsed.getLastRequest())) {
leastRecentlyUsed = sessions[i]; leastRecentlyUsed = sessions.get(i);
} }
} }
leastRecentlyUsed.expireNow(); leastRecentlyUsed.expireNow();
} }
public void checkAuthenticationAllowed(Authentication request) public void checkAuthenticationAllowed(Authentication request) throws AuthenticationException {
throws AuthenticationException {
Assert.notNull(request, "Authentication request cannot be null (violation of interface contract)"); Assert.notNull(request, "Authentication request cannot be null (violation of interface contract)");
Object principal = SessionRegistryUtils.obtainPrincipalFromAuthentication(request); Object principal = SessionRegistryUtils.obtainPrincipalFromAuthentication(request);
String sessionId = SessionRegistryUtils.obtainSessionIdFromAuthentication(request); String sessionId = SessionRegistryUtils.obtainSessionIdFromAuthentication(request);
SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false); final List<SessionInformation> sessions = sessionRegistry.getAllSessions(principal, false);
int sessionCount = 0; int sessionCount = sessions == null ? 0 : sessions.size();
if (sessions != null) {
sessionCount = sessions.length;
}
int allowableSessions = getMaximumSessionsForThisUser(request); int allowableSessions = getMaximumSessionsForThisUser(request);
Assert.isTrue(allowableSessions != 0, "getMaximumSessionsForThisUser() must return either -1 to allow " Assert.isTrue(allowableSessions != 0, "getMaximumSessionsForThisUser() must return either -1 to allow "
@ -106,13 +102,17 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl
if (sessionCount < allowableSessions) { if (sessionCount < allowableSessions) {
// They haven't got too many login sessions running at present // They haven't got too many login sessions running at present
return; return;
} else if (allowableSessions == -1) { }
if (allowableSessions == -1) {
// We permit unlimited logins // We permit unlimited logins
return; return;
} else if (sessionCount == allowableSessions) { }
if (sessionCount == allowableSessions) {
// Only permit it though if this request is associated with one of the sessions // Only permit it though if this request is associated with one of the sessions
for (int i = 0; i < sessionCount; i++) { for (SessionInformation si : sessions) {
if (sessions[i].getSessionId().equals(sessionId)) { if (si.getSessionId().equals(sessionId)) {
return; return;
} }
} }

View File

@ -15,6 +15,8 @@
package org.springframework.security.authentication.concurrent; package org.springframework.security.authentication.concurrent;
import java.util.List;
/** /**
* Maintains a registry of <code>SessionInformation</code> instances. * Maintains a registry of <code>SessionInformation</code> instances.
* *
@ -29,7 +31,7 @@ public interface SessionRegistry {
* *
* @return each of the unique principals, which can then be presented to {@link #getAllSessions(Object, boolean)}. * @return each of the unique principals, which can then be presented to {@link #getAllSessions(Object, boolean)}.
*/ */
Object[] getAllPrincipals(); List<Object> getAllPrincipals();
/** /**
* Obtains all the known sessions for the specified principal. Sessions that have been destroyed are not * Obtains all the known sessions for the specified principal. Sessions that have been destroyed are not
@ -41,7 +43,7 @@ public interface SessionRegistry {
* *
* @return the matching sessions for this principal, or <code>null</code> if none were found * @return the matching sessions for this principal, or <code>null</code> if none were found
*/ */
SessionInformation[] getAllSessions(Object principal, boolean includeExpiredSessions); List<SessionInformation> getAllSessions(Object principal, boolean includeExpiredSessions);
/** /**
* Obtains the session information for the specified <code>sessionId</code>. Even expired sessions are * Obtains the session information for the specified <code>sessionId</code>. Even expired sessions are

View File

@ -16,6 +16,7 @@
package org.springframework.security.authentication.concurrent; package org.springframework.security.authentication.concurrent;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
@ -57,18 +58,18 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
// ~ Methods ======================================================================================================= // ~ Methods =======================================================================================================
public Object[] getAllPrincipals() { public List<Object> getAllPrincipals() {
return principals.keySet().toArray(); return Arrays.asList(principals.keySet().toArray());
} }
public SessionInformation[] getAllSessions(Object principal, boolean includeExpiredSessions) { public List<SessionInformation> getAllSessions(Object principal, boolean includeExpiredSessions) {
Set<String> sessionsUsedByPrincipal = principals.get(principal); final Set<String> sessionsUsedByPrincipal = principals.get(principal);
if (sessionsUsedByPrincipal == null) { if (sessionsUsedByPrincipal == null) {
return null; return null;
} }
List<SessionInformation> list = new ArrayList<SessionInformation>(); List<SessionInformation> list = new ArrayList<SessionInformation>(sessionsUsedByPrincipal.size());
synchronized (sessionsUsedByPrincipal) { synchronized (sessionsUsedByPrincipal) {
for (String sessionId : sessionsUsedByPrincipal) { for (String sessionId : sessionsUsedByPrincipal) {
@ -84,7 +85,7 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
} }
} }
return (SessionInformation[]) list.toArray(new SessionInformation[0]); return list;
} }
public SessionInformation getSessionInformation(String sessionId) { public SessionInformation getSessionInformation(String sessionId) {

View File

@ -18,6 +18,7 @@ package org.springframework.security.authentication.concurrent;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import java.util.Date; import java.util.Date;
import java.util.List;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -77,8 +78,9 @@ public class SessionRegistryImplTests {
sessionRegistry.registerNewSession(sessionId2, principal1); sessionRegistry.registerNewSession(sessionId2, principal1);
sessionRegistry.registerNewSession(sessionId3, principal2); sessionRegistry.registerNewSession(sessionId3, principal2);
assertEquals(principal1, sessionRegistry.getAllPrincipals()[0]); assertEquals(2, sessionRegistry.getAllPrincipals().size());
assertEquals(principal2, sessionRegistry.getAllPrincipals()[1]); assertTrue(sessionRegistry.getAllPrincipals().contains(principal1));
assertTrue(sessionRegistry.getAllPrincipals().contains(principal2));
} }
@Test @Test
@ -95,7 +97,7 @@ public class SessionRegistryImplTests {
assertNotNull(sessionRegistry.getSessionInformation(sessionId).getLastRequest()); assertNotNull(sessionRegistry.getSessionInformation(sessionId).getLastRequest());
// Retrieve existing session by principal // Retrieve existing session by principal
assertEquals(1, sessionRegistry.getAllSessions(principal, false).length); assertEquals(1, sessionRegistry.getAllSessions(principal, false).size());
// Sleep to ensure SessionRegistryImpl will update time // Sleep to ensure SessionRegistryImpl will update time
Thread.sleep(1000); Thread.sleep(1000);
@ -107,7 +109,7 @@ public class SessionRegistryImplTests {
assertTrue(retrieved.after(currentDateTime)); assertTrue(retrieved.after(currentDateTime));
// Check it retrieves correctly when looked up via principal // Check it retrieves correctly when looked up via principal
assertEquals(retrieved, sessionRegistry.getAllSessions(principal, false)[0].getLastRequest()); assertEquals(retrieved, sessionRegistry.getAllSessions(principal, false).get(0).getLastRequest());
// Clear session information // Clear session information
sessionRegistry.removeSessionInformation(sessionId); sessionRegistry.removeSessionInformation(sessionId);
@ -124,13 +126,13 @@ public class SessionRegistryImplTests {
String sessionId2 = "9876543210"; String sessionId2 = "9876543210";
sessionRegistry.registerNewSession(sessionId1, principal); sessionRegistry.registerNewSession(sessionId1, principal);
SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false); List<SessionInformation> sessions = sessionRegistry.getAllSessions(principal, false);
assertEquals(1, sessions.length); assertEquals(1, sessions.size());
assertTrue(contains(sessionId1, principal)); assertTrue(contains(sessionId1, principal));
sessionRegistry.registerNewSession(sessionId2, principal); sessionRegistry.registerNewSession(sessionId2, principal);
sessions = sessionRegistry.getAllSessions(principal, false); sessions = sessionRegistry.getAllSessions(principal, false);
assertEquals(2, sessions.length); assertEquals(2, sessions.size());
assertTrue(contains(sessionId2, principal)); assertTrue(contains(sessionId2, principal));
// Expire one session // Expire one session
@ -149,18 +151,18 @@ public class SessionRegistryImplTests {
String sessionId2 = "9876543210"; String sessionId2 = "9876543210";
sessionRegistry.registerNewSession(sessionId1, principal); sessionRegistry.registerNewSession(sessionId1, principal);
SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false); List<SessionInformation> sessions = sessionRegistry.getAllSessions(principal, false);
assertEquals(1, sessions.length); assertEquals(1, sessions.size());
assertTrue(contains(sessionId1, principal)); assertTrue(contains(sessionId1, principal));
sessionRegistry.registerNewSession(sessionId2, principal); sessionRegistry.registerNewSession(sessionId2, principal);
sessions = sessionRegistry.getAllSessions(principal, false); sessions = sessionRegistry.getAllSessions(principal, false);
assertEquals(2, sessions.length); assertEquals(2, sessions.size());
assertTrue(contains(sessionId2, principal)); assertTrue(contains(sessionId2, principal));
sessionRegistry.removeSessionInformation(sessionId1); sessionRegistry.removeSessionInformation(sessionId1);
sessions = sessionRegistry.getAllSessions(principal, false); sessions = sessionRegistry.getAllSessions(principal, false);
assertEquals(1, sessions.length); assertEquals(1, sessions.size());
assertTrue(contains(sessionId2, principal)); assertTrue(contains(sessionId2, principal));
sessionRegistry.removeSessionInformation(sessionId2); sessionRegistry.removeSessionInformation(sessionId2);
@ -169,10 +171,10 @@ public class SessionRegistryImplTests {
} }
private boolean contains(String sessionId, Object principal) { private boolean contains(String sessionId, Object principal) {
SessionInformation[] info = sessionRegistry.getAllSessions(principal, false); List<SessionInformation> info = sessionRegistry.getAllSessions(principal, false);
for (int i = 0; i < info.length; i++) { for (int i = 0; i < info.size(); i++) {
if (sessionId.equals(info[i].getSessionId())) { if (sessionId.equals(info.get(i).getSessionId())) {
return true; return true;
} }
} }