SEC-803: Removed use of websphere SubjectHelper class.

This commit is contained in:
Luke Taylor 2008-05-14 22:51:39 +00:00
parent d4defb10fe
commit 6493df13f8

View File

@ -24,171 +24,182 @@ import org.apache.commons.logging.LogFactory;
* @since 2.0 * @since 2.0
*/ */
final class WASSecurityHelper { final class WASSecurityHelper {
private static final Log logger = LogFactory.getLog(WASSecurityHelper.class); private static final Log logger = LogFactory.getLog(WASSecurityHelper.class);
private static final String USER_REGISTRY = "UserRegistry"; private static final String USER_REGISTRY = "UserRegistry";
private static Method getRunAsSubject = null; private static Method getRunAsSubject = null;
private static Method getWSCredentialFromSubject = null; private static Method getGroupsForUser = null;
private static Method getGroupsForUser = null; private static Method getSecurityName = null;
private static Method getSecurityName = null; // SEC-803
private static Class wsCredentialClass = null;
/** /**
* Get the security name for the given subject. * Get the security name for the given subject.
* *
* @param subject * @param subject
* The subject for which to retrieve the security name * The subject for which to retrieve the security name
* @return String the security name for the given subject * @return String the security name for the given subject
*/ */
private static final String getSecurityName(final Subject subject) { private static final String getSecurityName(final Subject subject) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Determining Websphere security name for subject " + subject); logger.debug("Determining Websphere security name for subject " + subject);
} }
String userSecurityName = null; String userSecurityName = null;
if (subject != null) { if (subject != null) {
Object credential = invokeMethod(getWSCredentialFromSubjectMethod(),null,new Object[]{subject}); // SEC-803
if (credential != null) { Object credential = subject.getPublicCredentials(getWSCredentialClass()).iterator().next();
userSecurityName = (String)invokeMethod(getSecurityNameMethod(),credential,null); if (credential != null) {
} userSecurityName = (String)invokeMethod(getSecurityNameMethod(),credential,null);
} }
if (logger.isDebugEnabled()) { }
logger.debug("Websphere security name is " + userSecurityName + " for subject " + subject); if (logger.isDebugEnabled()) {
} logger.debug("Websphere security name is " + userSecurityName + " for subject " + subject);
return userSecurityName; }
} return userSecurityName;
}
/** /**
* Get the current RunAs subject. * Get the current RunAs subject.
* *
* @return Subject the current RunAs subject * @return Subject the current RunAs subject
*/ */
private static final Subject getRunAsSubject() { private static final Subject getRunAsSubject() {
logger.debug("Retrieving WebSphere RunAs subject"); logger.debug("Retrieving WebSphere RunAs subject");
// get Subject: WSSubject.getCallerSubject (); // get Subject: WSSubject.getCallerSubject ();
return (Subject) invokeMethod(getRunAsSubjectMethod(), null, new Object[] {}); return (Subject) invokeMethod(getRunAsSubjectMethod(), null, new Object[] {});
} }
/** /**
* Get the WebSphere group names for the given subject. * Get the WebSphere group names for the given subject.
* *
* @param subject * @param subject
* The subject for which to retrieve the WebSphere group names * The subject for which to retrieve the WebSphere group names
* @return the WebSphere group names for the given subject * @return the WebSphere group names for the given subject
*/ */
private static final String[] getWebSphereGroups(final Subject subject) { private static final String[] getWebSphereGroups(final Subject subject) {
return getWebSphereGroups(getSecurityName(subject)); return getWebSphereGroups(getSecurityName(subject));
} }
/** /**
* Get the WebSphere group names for the given security name. * Get the WebSphere group names for the given security name.
* *
* @param securityName * @param securityName
* The securityname for which to retrieve the WebSphere group names * The securityname for which to retrieve the WebSphere group names
* @return the WebSphere group names for the given security name * @return the WebSphere group names for the given security name
*/ */
private static final String[] getWebSphereGroups(final String securityName) { private static final String[] getWebSphereGroups(final String securityName) {
Context ic = null; Context ic = null;
try { try {
// TODO: Cache UserRegistry object // TODO: Cache UserRegistry object
ic = new InitialContext(); ic = new InitialContext();
Object objRef = ic.lookup(USER_REGISTRY); Object objRef = ic.lookup(USER_REGISTRY);
Object userReg = PortableRemoteObject.narrow(objRef, Class.forName ("com.ibm.websphere.security.UserRegistry")); Object userReg = PortableRemoteObject.narrow(objRef, Class.forName ("com.ibm.websphere.security.UserRegistry"));
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Determining WebSphere groups for user " + securityName + " using WebSphere UserRegistry " + userReg); logger.debug("Determining WebSphere groups for user " + securityName + " using WebSphere UserRegistry " + userReg);
} }
final Collection groups = (Collection) invokeMethod(getGroupsForUserMethod(), userReg, new Object[]{ securityName }); final Collection groups = (Collection) invokeMethod(getGroupsForUserMethod(), userReg, new Object[]{ securityName });
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Groups for user " + securityName + ": " + groups.toString()); logger.debug("Groups for user " + securityName + ": " + groups.toString());
} }
String[] result = new String[groups.size()]; String[] result = new String[groups.size()];
return (String[]) groups.toArray(result); return (String[]) groups.toArray(result);
} catch (Exception e) { } catch (Exception e) {
logger.error("Exception occured while looking up groups for user", e); logger.error("Exception occured while looking up groups for user", e);
throw new RuntimeException("Exception occured while looking up groups for user", e); throw new RuntimeException("Exception occured while looking up groups for user", e);
} finally { } finally {
try { try {
ic.close(); ic.close();
} catch (NamingException e) { } catch (NamingException e) {
logger.debug("Exception occured while closing context", e); logger.debug("Exception occured while closing context", e);
} }
} }
} }
/** /**
* @return * @return
*/ */
public static final String[] getGroupsForCurrentUser() { public static final String[] getGroupsForCurrentUser() {
return getWebSphereGroups(getRunAsSubject()); return getWebSphereGroups(getRunAsSubject());
} }
public static final String getCurrentUserName() { public static final String getCurrentUserName() {
return getSecurityName(getRunAsSubject()); return getSecurityName(getRunAsSubject());
} }
private static final Object invokeMethod(Method method, Object instance, Object[] args) private static final Object invokeMethod(Method method, Object instance, Object[] args)
{ {
try { try {
return method.invoke(instance,args); return method.invoke(instance,args);
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+ Arrays.asList(args)+")",e); logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+ Arrays.asList(args)+")",e);
throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e);
} catch (IllegalAccessException e) { } catch (IllegalAccessException e) {
logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e);
throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e);
} catch (InvocationTargetException e) { } catch (InvocationTargetException e) {
logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); logger.error("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e);
throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e); throw new RuntimeException("Error while invoking method "+method.getClass().getName()+"."+method.getName()+"("+Arrays.asList(args)+")",e);
} }
} }
private static final Method getMethod(String className, String methodName, String[] parameterTypeNames) { private static final Method getMethod(String className, String methodName, String[] parameterTypeNames) {
try { try {
Class c = Class.forName(className); Class c = Class.forName(className);
final int len = parameterTypeNames.length; final int len = parameterTypeNames.length;
Class[] parameterTypes = new Class[len]; Class[] parameterTypes = new Class[len];
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {
parameterTypes[i] = Class.forName(parameterTypeNames[i]); parameterTypes[i] = Class.forName(parameterTypeNames[i]);
} }
return c.getDeclaredMethod(methodName, parameterTypes); return c.getDeclaredMethod(methodName, parameterTypes);
} catch (ClassNotFoundException e) { } catch (ClassNotFoundException e) {
logger.error("Required class"+className+" not found"); logger.error("Required class"+className+" not found");
throw new RuntimeException("Required class"+className+" not found",e); throw new RuntimeException("Required class"+className+" not found",e);
} catch (NoSuchMethodException e) { } catch (NoSuchMethodException e) {
logger.error("Required method "+methodName+" with parameter types ("+ Arrays.asList(parameterTypeNames) +") not found on class "+className); logger.error("Required method "+methodName+" with parameter types ("+ Arrays.asList(parameterTypeNames) +") not found on class "+className);
throw new RuntimeException("Required class"+className+" not found",e); throw new RuntimeException("Required class"+className+" not found",e);
} }
} }
private static final Method getRunAsSubjectMethod() { private static final Method getRunAsSubjectMethod() {
if (getRunAsSubject == null) { if (getRunAsSubject == null) {
getRunAsSubject = getMethod("com.ibm.websphere.security.auth.WSSubject", "getRunAsSubject", new String[] {}); getRunAsSubject = getMethod("com.ibm.websphere.security.auth.WSSubject", "getRunAsSubject", new String[] {});
} }
return getRunAsSubject; return getRunAsSubject;
} }
private static final Method getWSCredentialFromSubjectMethod() { private static final Method getGroupsForUserMethod() {
if (getWSCredentialFromSubject == null) { if (getGroupsForUser == null) {
getWSCredentialFromSubject = getMethod("com.ibm.ws.security.auth.SubjectHelper", "getWSCredentialFromSubject", getGroupsForUser = getMethod("com.ibm.websphere.security.UserRegistry", "getGroupsForUser", new String[] { "java.lang.String" });
new String[] { "javax.security.auth.Subject" }); }
} return getGroupsForUser;
return getWSCredentialFromSubject; }
}
private static final Method getGroupsForUserMethod() { private static final Method getSecurityNameMethod() {
if (getGroupsForUser == null) { if (getSecurityName == null) {
getGroupsForUser = getMethod("com.ibm.websphere.security.UserRegistry", "getGroupsForUser", new String[] { "java.lang.String" }); getSecurityName = getMethod("com.ibm.websphere.security.cred.WSCredential", "getSecurityName", new String[] {});
} }
return getGroupsForUser; return getSecurityName;
} }
private static final Method getSecurityNameMethod() { // SEC-803
if (getSecurityName == null) { private static final Class getWSCredentialClass() {
getSecurityName = getMethod("com.ibm.websphere.security.cred.WSCredential", "getSecurityName", new String[] {}); if (wsCredentialClass == null) {
} wsCredentialClass = getClass("com.ibm.websphere.security.cred.WSCredential");
return getSecurityName; }
} return wsCredentialClass;
}
private static final Class getClass(String className) {
try {
return Class.forName(className);
} catch (ClassNotFoundException e) {
logger.error("Required class " + className + " not found");
throw new RuntimeException("Required class " + className + " not found",e);
}
}
} }