diff --git a/activemq-client/src/main/java/org/apache/activemq/util/ClassLoadingAwareObjectInputStream.java b/activemq-client/src/main/java/org/apache/activemq/util/ClassLoadingAwareObjectInputStream.java index ccee17d433..ca4d8a356a 100644 --- a/activemq-client/src/main/java/org/apache/activemq/util/ClassLoadingAwareObjectInputStream.java +++ b/activemq-client/src/main/java/org/apache/activemq/util/ClassLoadingAwareObjectInputStream.java @@ -21,19 +21,16 @@ import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectStreamClass; import java.lang.reflect.Proxy; -import java.util.HashMap; -@SuppressWarnings("rawtypes") +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + public class ClassLoadingAwareObjectInputStream extends ObjectInputStream { + private static final Logger LOG = LoggerFactory.getLogger(ClassLoadingAwareObjectInputStream.class); private static final ClassLoader FALLBACK_CLASS_LOADER = ClassLoadingAwareObjectInputStream.class.getClassLoader(); - /** - * Maps primitive type names to corresponding class objects. - */ - private static final HashMap primClasses = new HashMap(8, 1.0F); - private final ClassLoader inLoader; public ClassLoadingAwareObjectInputStream(InputStream in) throws IOException { @@ -72,31 +69,84 @@ public class ClassLoadingAwareObjectInputStream extends ObjectInputStream { } private Class load(String className, ClassLoader... cl) throws ClassNotFoundException { + // check for simple types first + final Class clazz = loadSimpleType(className); + if (clazz != null) { + LOG.trace("Loaded class: {} as simple type -> ", className, clazz); + return clazz; + } + + // try the different class loaders for (ClassLoader loader : cl) { + LOG.trace("Attempting to load class: {} using classloader: {}", className, cl); try { - return Class.forName(className, false, loader); + Class answer = Class.forName(className, false, loader); + if (LOG.isTraceEnabled()) { + LOG.trace("Loaded class: {} using classloader: {} -> ", new Object[]{className, cl, answer}); + } + return answer; } catch (ClassNotFoundException e) { + LOG.trace("Class not found: {} using classloader: {}", className, cl); // ignore } } - // fallback - final Class clazz = (Class) primClasses.get(className); - if (clazz != null) { - return clazz; - } else { - return Class.forName(className, false, FALLBACK_CLASS_LOADER); - } + + // and then the fallback class loader + return Class.forName(className, false, FALLBACK_CLASS_LOADER); } - static { - primClasses.put("boolean", boolean.class); - primClasses.put("byte", byte.class); - primClasses.put("char", char.class); - primClasses.put("short", short.class); - primClasses.put("int", int.class); - primClasses.put("long", long.class); - primClasses.put("float", float.class); - primClasses.put("double", double.class); - primClasses.put("void", void.class); + /** + * Load a simple type + * + * @param name the name of the class to load + * @return the class or null if it could not be loaded + */ + public static Class loadSimpleType(String name) { + // code from ObjectHelper.loadSimpleType in Apache Camel + + // special for byte[] or Object[] as its common to use + if ("java.lang.byte[]".equals(name) || "byte[]".equals(name)) { + return byte[].class; + } else if ("java.lang.Byte[]".equals(name) || "Byte[]".equals(name)) { + return Byte[].class; + } else if ("java.lang.Object[]".equals(name) || "Object[]".equals(name)) { + return Object[].class; + } else if ("java.lang.String[]".equals(name) || "String[]".equals(name)) { + return String[].class; + // and these is common as well + } else if ("java.lang.String".equals(name) || "String".equals(name)) { + return String.class; + } else if ("java.lang.Boolean".equals(name) || "Boolean".equals(name)) { + return Boolean.class; + } else if ("boolean".equals(name)) { + return boolean.class; + } else if ("java.lang.Integer".equals(name) || "Integer".equals(name)) { + return Integer.class; + } else if ("int".equals(name)) { + return int.class; + } else if ("java.lang.Long".equals(name) || "Long".equals(name)) { + return Long.class; + } else if ("long".equals(name)) { + return long.class; + } else if ("java.lang.Short".equals(name) || "Short".equals(name)) { + return Short.class; + } else if ("short".equals(name)) { + return short.class; + } else if ("java.lang.Byte".equals(name) || "Byte".equals(name)) { + return Byte.class; + } else if ("byte".equals(name)) { + return byte.class; + } else if ("java.lang.Float".equals(name) || "Float".equals(name)) { + return Float.class; + } else if ("float".equals(name)) { + return float.class; + } else if ("java.lang.Double".equals(name) || "Double".equals(name)) { + return Double.class; + } else if ("double".equals(name)) { + return double.class; + } + + return null; } + }