From 95ab95b6e328db2c7188caaf2f071f18f6f0133c Mon Sep 17 00:00:00 2001 From: Luke Taylor Date: Mon, 27 Apr 2009 07:12:12 +0000 Subject: [PATCH] SEC-1078: Missed commit of default strategy class. --- .../DefaultWASUsernameAndGroupsExtractor.java | 205 ++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100755 web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java new file mode 100755 index 0000000000..aeced5c008 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/websphere/DefaultWASUsernameAndGroupsExtractor.java @@ -0,0 +1,205 @@ +package org.springframework.security.web.authentication.preauth.websphere; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import javax.naming.Context; +import javax.naming.InitialContext; +import javax.naming.NamingException; +import javax.rmi.PortableRemoteObject; +import javax.security.auth.Subject; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +/** + * WebSphere Security helper class to allow retrieval of the current username and groups. + *

+ * See Spring Security Jira SEC-477. + * + * @author Ruud Senden + * @author Stephane Manciot + * @since 2.0 + */ +final class DefaultWASUsernameAndGroupsExtractor implements WASUsernameAndGroupsExtractor { + private static final Log logger = LogFactory.getLog(DefaultWASUsernameAndGroupsExtractor.class); + + private static final String USER_REGISTRY = "UserRegistry"; + + private static Method getRunAsSubject = null; + + private static Method getGroupsForUser = null; + + private static Method getSecurityName = null; + + // SEC-803 + private static Class wsCredentialClass = null; + + public final List getGroupsForCurrentUser() { + return getWebSphereGroups(getRunAsSubject()); + } + + public final String getCurrentUserName() { + return getSecurityName(getRunAsSubject()); + } + + /** + * Get the security name for the given subject. + * + * @param subject + * The subject for which to retrieve the security name + * @return String the security name for the given subject + */ + private static final String getSecurityName(final Subject subject) { + if (logger.isDebugEnabled()) { + logger.debug("Determining Websphere security name for subject " + subject); + } + String userSecurityName = null; + if (subject != null) { + // SEC-803 + Object credential = subject.getPublicCredentials(getWSCredentialClass()).iterator().next(); + if (credential != null) { + userSecurityName = (String)invokeMethod(getSecurityNameMethod(),credential,null); + } + } + if (logger.isDebugEnabled()) { + logger.debug("Websphere security name is " + userSecurityName + " for subject " + subject); + } + return userSecurityName; + } + + /** + * Get the current RunAs subject. + * + * @return Subject the current RunAs subject + */ + private static final Subject getRunAsSubject() { + logger.debug("Retrieving WebSphere RunAs subject"); + // get Subject: WSSubject.getCallerSubject (); + return (Subject) invokeMethod(getRunAsSubjectMethod(), null, new Object[] {}); + } + + /** + * Get the WebSphere group names for the given subject. + * + * @param subject + * The subject for which to retrieve the WebSphere group names + * @return the WebSphere group names for the given subject + */ + private static final List getWebSphereGroups(final Subject subject) { + return getWebSphereGroups(getSecurityName(subject)); + } + + /** + * Get the WebSphere group names for the given security name. + * + * @param securityName + * The security name for which to retrieve the WebSphere group names + * @return the WebSphere group names for the given security name + */ + @SuppressWarnings("unchecked") + private static final List getWebSphereGroups(final String securityName) { + Context ic = null; + try { + // TODO: Cache UserRegistry object + ic = new InitialContext(); + Object objRef = ic.lookup(USER_REGISTRY); + Object userReg = PortableRemoteObject.narrow(objRef, Class.forName ("com.ibm.websphere.security.UserRegistry")); + if (logger.isDebugEnabled()) { + logger.debug("Determining WebSphere groups for user " + securityName + " using WebSphere UserRegistry " + userReg); + } + final Collection groups = (Collection) invokeMethod(getGroupsForUserMethod(), userReg, new Object[]{ securityName }); + if (logger.isDebugEnabled()) { + logger.debug("Groups for user " + securityName + ": " + groups.toString()); + } + + return new ArrayList(groups); + } catch (Exception e) { + logger.error("Exception occured while looking up groups for user", e); + throw new RuntimeException("Exception occured while looking up groups for user", e); + } finally { + try { + ic.close(); + } catch (NamingException e) { + logger.debug("Exception occured while closing context", e); + } + } + } + + private static final Object invokeMethod(Method method, Object instance, Object[] args) + { + try { + return method.invoke(instance,args); + } catch (IllegalArgumentException e) { + logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+ Arrays.asList(args)+")",e); + throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); + } catch (IllegalAccessException e) { + logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); + throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); + } catch (InvocationTargetException e) { + logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); + throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); + } + } + + private static final Method getMethod(String className, String methodName, String[] parameterTypeNames) { + try { + Class c = Class.forName(className); + final int len = parameterTypeNames.length; + Class[] parameterTypes = new Class[len]; + for (int i = 0; i < len; i++) { + parameterTypes[i] = Class.forName(parameterTypeNames[i]); + } + return c.getDeclaredMethod(methodName, parameterTypes); + } catch (ClassNotFoundException e) { + logger.error("Required class"+className+" not found"); + throw new RuntimeException("Required class"+className+" not found",e); + } catch (NoSuchMethodException e) { + logger.error("Required method "+methodName+" with parameter types ("+ Arrays.asList(parameterTypeNames) +") not found on class "+className); + throw new RuntimeException("Required class"+className+" not found",e); + } + } + + private static final Method getRunAsSubjectMethod() { + if (getRunAsSubject == null) { + getRunAsSubject = getMethod("com.ibm.websphere.security.auth.WSSubject", "getRunAsSubject", new String[] {}); + } + return getRunAsSubject; + } + + private static final Method getGroupsForUserMethod() { + if (getGroupsForUser == null) { + getGroupsForUser = getMethod("com.ibm.websphere.security.UserRegistry", "getGroupsForUser", new String[] { "java.lang.String" }); + } + return getGroupsForUser; + } + + private static final Method getSecurityNameMethod() { + if (getSecurityName == null) { + getSecurityName = getMethod("com.ibm.websphere.security.cred.WSCredential", "getSecurityName", new String[] {}); + } + return getSecurityName; + } + + // SEC-803 + private static final Class getWSCredentialClass() { + if (wsCredentialClass == null) { + wsCredentialClass = getClass("com.ibm.websphere.security.cred.WSCredential"); + } + return wsCredentialClass; + } + + private static final Class getClass(String className) { + try { + return Class.forName(className); + } catch (ClassNotFoundException e) { + logger.error("Required class " + className + " not found"); + throw new RuntimeException("Required class " + className + " not found",e); + } + } + +}