442495 - Bad Context ClassLoader in JSR356 WebSocket onOpen

+ Fixing onOpen context classloader to be that of the context
  that started the WebSocketUpgradeFilter (which will be the
  same as the WebAppContext in most cases)
This commit is contained in:
Joakim Erdfelt 2014-09-22 14:37:00 -07:00 committed by Simone Bordet
parent 901707b894
commit 0bf68a07ae
4 changed files with 114 additions and 85 deletions

View File

@ -131,105 +131,113 @@ public class WebSocketServerContainerInitializer implements ServletContainerInit
ServletContextHandler jettyContext = (ServletContextHandler)handler; ServletContextHandler jettyContext = (ServletContextHandler)handler;
// Create the Jetty ServerContainer implementation ClassLoader old = Thread.currentThread().getContextClassLoader();
ServerContainer jettyContainer = configureContext(context, jettyContext); try
// Store a reference to the ServerContainer per javax.websocket spec 1.0 final section 6.4 Programmatic Server Deployment
context.setAttribute(javax.websocket.server.ServerContainer.class.getName(),jettyContainer);
if (LOG.isDebugEnabled())
{ {
LOG.debug("Found {} classes",c.size()); Thread.currentThread().setContextClassLoader(context.getClassLoader());
}
// Now process the incoming classes // Create the Jetty ServerContainer implementation
Set<Class<? extends Endpoint>> discoveredExtendedEndpoints = new HashSet<>(); ServerContainer jettyContainer = configureContext(context,jettyContext);
Set<Class<?>> discoveredAnnotatedEndpoints = new HashSet<>();
Set<Class<? extends ServerApplicationConfig>> serverAppConfigs = new HashSet<>();
filterClasses(c,discoveredExtendedEndpoints,discoveredAnnotatedEndpoints,serverAppConfigs); // Store a reference to the ServerContainer per javax.websocket spec 1.0 final section 6.4 Programmatic Server Deployment
context.setAttribute(javax.websocket.server.ServerContainer.class.getName(),jettyContainer);
if (LOG.isDebugEnabled())
{
LOG.debug("Discovered {} extends Endpoint classes",discoveredExtendedEndpoints.size());
LOG.debug("Discovered {} @ServerEndpoint classes",discoveredAnnotatedEndpoints.size());
LOG.debug("Discovered {} ServerApplicationConfig classes",serverAppConfigs.size());
}
// Process the server app configs to determine endpoint filtering
boolean wasFiltered = false;
Set<ServerEndpointConfig> deployableExtendedEndpointConfigs = new HashSet<>();
Set<Class<?>> deployableAnnotatedEndpoints = new HashSet<>();
for (Class<? extends ServerApplicationConfig> clazz : serverAppConfigs)
{
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
{ {
LOG.debug("Found ServerApplicationConfig: {}",clazz); LOG.debug("Found {} classes",c.size());
} }
try
// Now process the incoming classes
Set<Class<? extends Endpoint>> discoveredExtendedEndpoints = new HashSet<>();
Set<Class<?>> discoveredAnnotatedEndpoints = new HashSet<>();
Set<Class<? extends ServerApplicationConfig>> serverAppConfigs = new HashSet<>();
filterClasses(c,discoveredExtendedEndpoints,discoveredAnnotatedEndpoints,serverAppConfigs);
if (LOG.isDebugEnabled())
{ {
ServerApplicationConfig config = clazz.newInstance(); LOG.debug("Discovered {} extends Endpoint classes",discoveredExtendedEndpoints.size());
LOG.debug("Discovered {} @ServerEndpoint classes",discoveredAnnotatedEndpoints.size());
LOG.debug("Discovered {} ServerApplicationConfig classes",serverAppConfigs.size());
}
Set<ServerEndpointConfig> seconfigs = config.getEndpointConfigs(discoveredExtendedEndpoints); // Process the server app configs to determine endpoint filtering
if (seconfigs != null) boolean wasFiltered = false;
Set<ServerEndpointConfig> deployableExtendedEndpointConfigs = new HashSet<>();
Set<Class<?>> deployableAnnotatedEndpoints = new HashSet<>();
for (Class<? extends ServerApplicationConfig> clazz : serverAppConfigs)
{
if (LOG.isDebugEnabled())
{ {
wasFiltered = true; LOG.debug("Found ServerApplicationConfig: {}",clazz);
deployableExtendedEndpointConfigs.addAll(seconfigs);
} }
try
Set<Class<?>> annotatedClasses = config.getAnnotatedEndpointClasses(discoveredAnnotatedEndpoints);
if (annotatedClasses != null)
{ {
wasFiltered = true; ServerApplicationConfig config = clazz.newInstance();
deployableAnnotatedEndpoints.addAll(annotatedClasses);
Set<ServerEndpointConfig> seconfigs = config.getEndpointConfigs(discoveredExtendedEndpoints);
if (seconfigs != null)
{
wasFiltered = true;
deployableExtendedEndpointConfigs.addAll(seconfigs);
}
Set<Class<?>> annotatedClasses = config.getAnnotatedEndpointClasses(discoveredAnnotatedEndpoints);
if (annotatedClasses != null)
{
wasFiltered = true;
deployableAnnotatedEndpoints.addAll(annotatedClasses);
}
}
catch (InstantiationException | IllegalAccessException e)
{
throw new ServletException("Unable to instantiate: " + clazz.getName(),e);
} }
} }
catch (InstantiationException | IllegalAccessException e)
{
throw new ServletException("Unable to instantiate: " + clazz.getName(),e);
}
}
// Default behavior if nothing filtered // Default behavior if nothing filtered
if (!wasFiltered) if (!wasFiltered)
{ {
deployableAnnotatedEndpoints.addAll(discoveredAnnotatedEndpoints); deployableAnnotatedEndpoints.addAll(discoveredAnnotatedEndpoints);
// Note: it is impossible to determine path of "extends Endpoint" discovered classes // Note: it is impossible to determine path of "extends Endpoint" discovered classes
deployableExtendedEndpointConfigs = new HashSet<>(); deployableExtendedEndpointConfigs = new HashSet<>();
} }
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
{
LOG.debug("Deploying {} ServerEndpointConfig(s)",deployableExtendedEndpointConfigs.size());
}
// Deploy what should be deployed.
for (ServerEndpointConfig config : deployableExtendedEndpointConfigs)
{
try
{ {
jettyContainer.addEndpoint(config); LOG.debug("Deploying {} ServerEndpointConfig(s)",deployableExtendedEndpointConfigs.size());
} }
catch (DeploymentException e) // Deploy what should be deployed.
for (ServerEndpointConfig config : deployableExtendedEndpointConfigs)
{ {
throw new ServletException(e); try
{
jettyContainer.addEndpoint(config);
}
catch (DeploymentException e)
{
throw new ServletException(e);
}
} }
}
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
{
LOG.debug("Deploying {} @ServerEndpoint(s)",deployableAnnotatedEndpoints.size());
}
for (Class<?> annotatedClass : deployableAnnotatedEndpoints)
{
try
{ {
jettyContainer.addEndpoint(annotatedClass); LOG.debug("Deploying {} @ServerEndpoint(s)",deployableAnnotatedEndpoints.size());
} }
catch (DeploymentException e) for (Class<?> annotatedClass : deployableAnnotatedEndpoints)
{ {
throw new ServletException(e); try
{
jettyContainer.addEndpoint(annotatedClass);
}
catch (DeploymentException e)
{
throw new ServletException(e);
}
} }
} finally {
Thread.currentThread().setContextClassLoader(old);
} }
} }

View File

@ -61,6 +61,7 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
private final LogicalConnection connection; private final LogicalConnection connection;
private final SessionListener[] sessionListeners; private final SessionListener[] sessionListeners;
private final Executor executor; private final Executor executor;
private ClassLoader classLoader;
private ExtensionFactory extensionFactory; private ExtensionFactory extensionFactory;
private String protocolVersion; private String protocolVersion;
private Map<String, String[]> parameterMap = new HashMap<>(); private Map<String, String[]> parameterMap = new HashMap<>();
@ -78,6 +79,7 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
throw new RuntimeException("Request URI cannot be null"); throw new RuntimeException("Request URI cannot be null");
} }
this.classLoader = Thread.currentThread().getContextClassLoader();
this.requestURI = requestURI; this.requestURI = requestURI;
this.websocket = websocket; this.websocket = websocket;
this.connection = connection; this.connection = connection;
@ -182,6 +184,11 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
{ {
return this.connection.getBufferPool(); return this.connection.getBufferPool();
} }
public ClassLoader getClassLoader()
{
return this.getClass().getClassLoader();
}
public LogicalConnection getConnection() public LogicalConnection getConnection()
{ {
@ -393,15 +400,17 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
// already opened // already opened
return; return;
} }
ClassLoader old = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(classLoader);
// Upgrade success // Upgrade success
connection.getIOState().onConnected(); connection.getIOState().onConnected();
// Connect remote
remote = new WebSocketRemoteEndpoint(connection,outgoingHandler,getBatchMode());
// Connect remote
remote = new WebSocketRemoteEndpoint(connection,outgoingHandler,getBatchMode());
try
{
// Open WebSocket // Open WebSocket
websocket.openSession(this); websocket.openSession(this);
@ -425,6 +434,10 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
close(statusCode,t.getMessage()); close(statusCode,t.getMessage());
} }
finally
{
Thread.currentThread().setContextClassLoader(old);
}
} }
public void setExtensionFactory(ExtensionFactory extensionFactory) public void setExtensionFactory(ExtensionFactory extensionFactory)
@ -506,4 +519,5 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
builder.append("]"); builder.append("]");
return builder.toString(); return builder.toString();
} }
} }

View File

@ -30,6 +30,7 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -70,6 +71,7 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
{ {
private static final Logger LOG = Log.getLogger(WebSocketServerFactory.class); private static final Logger LOG = Log.getLogger(WebSocketServerFactory.class);
private final ClassLoader contextClassloader;
private final Map<Integer, WebSocketHandshake> handshakes = new HashMap<>(); private final Map<Integer, WebSocketHandshake> handshakes = new HashMap<>();
/** /**
* Have the factory maintain 1 and only 1 scheduler. All connections share this scheduler. * Have the factory maintain 1 and only 1 scheduler. All connections share this scheduler.
@ -106,6 +108,8 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
addBean(scheduler); addBean(scheduler);
addBean(bufferPool); addBean(bufferPool);
this.contextClassloader = Thread.currentThread().getContextClassLoader();
this.registeredSocketClasses = new ArrayList<>(); this.registeredSocketClasses = new ArrayList<>();
@ -151,8 +155,10 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
@Override @Override
public boolean acceptWebSocket(WebSocketCreator creator, HttpServletRequest request, HttpServletResponse response) throws IOException public boolean acceptWebSocket(WebSocketCreator creator, HttpServletRequest request, HttpServletResponse response) throws IOException
{ {
ClassLoader old = Thread.currentThread().getContextClassLoader();
try try
{ {
Thread.currentThread().setContextClassLoader(contextClassloader);
ServletUpgradeRequest sockreq = new ServletUpgradeRequest(request); ServletUpgradeRequest sockreq = new ServletUpgradeRequest(request);
ServletUpgradeResponse sockresp = new ServletUpgradeResponse(response); ServletUpgradeResponse sockresp = new ServletUpgradeResponse(response);
@ -181,6 +187,10 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
catch (URISyntaxException e) catch (URISyntaxException e)
{ {
throw new IOException("Unable to accept websocket due to mangled URI", e); throw new IOException("Unable to accept websocket due to mangled URI", e);
}
finally
{
Thread.currentThread().setContextClassLoader(old);
} }
} }

View File

@ -91,10 +91,7 @@ public class WebSocketUpgradeFilter extends ContainerLifeCycle implements Filter
String pathSpec = "/*"; String pathSpec = "/*";
EnumSet<DispatcherType> dispatcherTypes = EnumSet.of(DispatcherType.REQUEST); EnumSet<DispatcherType> dispatcherTypes = EnumSet.of(DispatcherType.REQUEST);
boolean isMatchAfter = false; boolean isMatchAfter = false;
String urlPatterns[] = String urlPatterns[] = { pathSpec };
{
pathSpec
};
FilterRegistration.Dynamic dyn = context.addFilter(name,filter); FilterRegistration.Dynamic dyn = context.addFilter(name,filter);
dyn.addMappingForUrlPatterns(dispatcherTypes,isMatchAfter,urlPatterns); dyn.addMappingForUrlPatterns(dispatcherTypes,isMatchAfter,urlPatterns);