diff --git a/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java b/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java index 8d0e3307ec..566b8696f4 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java @@ -15,21 +15,16 @@ package org.springframework.security.core.session; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Date; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.context.ApplicationListener; import org.springframework.util.Assert; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArraySet; + /** * Default implementation of {@link org.springframework.security.core.session.SessionRegistry SessionRegistry} * which listens for {@link org.springframework.security.core.session.SessionDestroyedEvent SessionDestroyedEvent}s @@ -49,14 +44,14 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener protected final Log logger = LogFactory.getLog(SessionRegistryImpl.class); /** */ - private final Map> principals = Collections.synchronizedMap(new HashMap>()); + private final ConcurrentMap> principals = new ConcurrentHashMap>(); /** */ - private final Map sessionIds = Collections.synchronizedMap(new HashMap()); + private final Map sessionIds = new ConcurrentHashMap(); //~ Methods ======================================================================================================== public List getAllPrincipals() { - return Arrays.asList(principals.keySet().toArray()); + return new ArrayList(principals.keySet()); } public List getAllSessions(Object principal, boolean includeExpiredSessions) { @@ -68,17 +63,15 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener List list = new ArrayList(sessionsUsedByPrincipal.size()); - synchronized (sessionsUsedByPrincipal) { - for (String sessionId : sessionsUsedByPrincipal) { - SessionInformation sessionInformation = getSessionInformation(sessionId); + for (String sessionId : sessionsUsedByPrincipal) { + SessionInformation sessionInformation = getSessionInformation(sessionId); - if (sessionInformation == null) { - continue; - } + if (sessionInformation == null) { + continue; + } - if (includeExpiredSessions || !sessionInformation.isExpired()) { - list.add(sessionInformation); - } + if (includeExpiredSessions || !sessionInformation.isExpired()) { + list.add(sessionInformation); } } @@ -88,7 +81,7 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener public SessionInformation getSessionInformation(String sessionId) { Assert.hasText(sessionId, "SessionId required as per interface contract"); - return (SessionInformation) sessionIds.get(sessionId); + return sessionIds.get(sessionId); } public void onApplicationEvent(SessionDestroyedEvent event) { @@ -106,7 +99,7 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener } } - public synchronized void registerNewSession(String sessionId, Object principal) { + public void registerNewSession(String sessionId, Object principal) { Assert.hasText(sessionId, "SessionId required as per interface contract"); Assert.notNull(principal, "Principal required as per interface contract"); @@ -123,8 +116,12 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener Set sessionsUsedByPrincipal = principals.get(principal); if (sessionsUsedByPrincipal == null) { - sessionsUsedByPrincipal = Collections.synchronizedSet(new HashSet(4)); - principals.put(principal, sessionsUsedByPrincipal); + sessionsUsedByPrincipal = new CopyOnWriteArraySet(); + Set prevSessionsUsedByPrincipal = principals.putIfAbsent(principal, + sessionsUsedByPrincipal); + if (prevSessionsUsedByPrincipal != null) { + sessionsUsedByPrincipal = prevSessionsUsedByPrincipal; + } } sessionsUsedByPrincipal.add(sessionId); @@ -159,20 +156,19 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener logger.debug("Removing session " + sessionId + " from principal's set of registered sessions"); } - synchronized (sessionsUsedByPrincipal) { - sessionsUsedByPrincipal.remove(sessionId); + sessionsUsedByPrincipal.remove(sessionId); - if (sessionsUsedByPrincipal.size() == 0) { - // No need to keep object in principals Map anymore - if (logger.isDebugEnabled()) { - logger.debug("Removing principal " + info.getPrincipal() + " from registry"); - } - principals.remove(info.getPrincipal()); + if (sessionsUsedByPrincipal.isEmpty()) { + // No need to keep object in principals Map anymore + if (logger.isDebugEnabled()) { + logger.debug("Removing principal " + info.getPrincipal() + " from registry"); } + principals.remove(info.getPrincipal()); } if (logger.isTraceEnabled()) { logger.trace("Sessions used by '" + info.getPrincipal() + "' : " + sessionsUsedByPrincipal); } } + }