442495 - Bad Context ClassLoader in JSR356 WebSocket onOpen

+ Fixing onOpen context classloader to be that of the context
  that started the WebSocketUpgradeFilter (which will be the
  same as the WebAppContext in most cases)
This commit is contained in:
Joakim Erdfelt 2014-09-22 14:37:00 -07:00
parent d6082b2d65
commit 0dca1b0794
4 changed files with 114 additions and 85 deletions

View File

@ -131,105 +131,113 @@ public class WebSocketServerContainerInitializer implements ServletContainerInit
ServletContextHandler jettyContext = (ServletContextHandler)handler;
// Create the Jetty ServerContainer implementation
ServerContainer jettyContainer = configureContext(context, jettyContext);
// Store a reference to the ServerContainer per javax.websocket spec 1.0 final section 6.4 Programmatic Server Deployment
context.setAttribute(javax.websocket.server.ServerContainer.class.getName(),jettyContainer);
if (LOG.isDebugEnabled())
ClassLoader old = Thread.currentThread().getContextClassLoader();
try
{
LOG.debug("Found {} classes",c.size());
}
Thread.currentThread().setContextClassLoader(context.getClassLoader());
// Now process the incoming classes
Set<Class<? extends Endpoint>> discoveredExtendedEndpoints = new HashSet<>();
Set<Class<?>> discoveredAnnotatedEndpoints = new HashSet<>();
Set<Class<? extends ServerApplicationConfig>> serverAppConfigs = new HashSet<>();
// Create the Jetty ServerContainer implementation
ServerContainer jettyContainer = configureContext(context,jettyContext);
filterClasses(c,discoveredExtendedEndpoints,discoveredAnnotatedEndpoints,serverAppConfigs);
// Store a reference to the ServerContainer per javax.websocket spec 1.0 final section 6.4 Programmatic Server Deployment
context.setAttribute(javax.websocket.server.ServerContainer.class.getName(),jettyContainer);
if (LOG.isDebugEnabled())
{
LOG.debug("Discovered {} extends Endpoint classes",discoveredExtendedEndpoints.size());
LOG.debug("Discovered {} @ServerEndpoint classes",discoveredAnnotatedEndpoints.size());
LOG.debug("Discovered {} ServerApplicationConfig classes",serverAppConfigs.size());
}
// Process the server app configs to determine endpoint filtering
boolean wasFiltered = false;
Set<ServerEndpointConfig> deployableExtendedEndpointConfigs = new HashSet<>();
Set<Class<?>> deployableAnnotatedEndpoints = new HashSet<>();
for (Class<? extends ServerApplicationConfig> clazz : serverAppConfigs)
{
if (LOG.isDebugEnabled())
{
LOG.debug("Found ServerApplicationConfig: {}",clazz);
LOG.debug("Found {} classes",c.size());
}
try
// Now process the incoming classes
Set<Class<? extends Endpoint>> discoveredExtendedEndpoints = new HashSet<>();
Set<Class<?>> discoveredAnnotatedEndpoints = new HashSet<>();
Set<Class<? extends ServerApplicationConfig>> serverAppConfigs = new HashSet<>();
filterClasses(c,discoveredExtendedEndpoints,discoveredAnnotatedEndpoints,serverAppConfigs);
if (LOG.isDebugEnabled())
{
ServerApplicationConfig config = clazz.newInstance();
LOG.debug("Discovered {} extends Endpoint classes",discoveredExtendedEndpoints.size());
LOG.debug("Discovered {} @ServerEndpoint classes",discoveredAnnotatedEndpoints.size());
LOG.debug("Discovered {} ServerApplicationConfig classes",serverAppConfigs.size());
}
Set<ServerEndpointConfig> seconfigs = config.getEndpointConfigs(discoveredExtendedEndpoints);
if (seconfigs != null)
// Process the server app configs to determine endpoint filtering
boolean wasFiltered = false;
Set<ServerEndpointConfig> deployableExtendedEndpointConfigs = new HashSet<>();
Set<Class<?>> deployableAnnotatedEndpoints = new HashSet<>();
for (Class<? extends ServerApplicationConfig> clazz : serverAppConfigs)
{
if (LOG.isDebugEnabled())
{
wasFiltered = true;
deployableExtendedEndpointConfigs.addAll(seconfigs);
LOG.debug("Found ServerApplicationConfig: {}",clazz);
}
Set<Class<?>> annotatedClasses = config.getAnnotatedEndpointClasses(discoveredAnnotatedEndpoints);
if (annotatedClasses != null)
try
{
wasFiltered = true;
deployableAnnotatedEndpoints.addAll(annotatedClasses);
ServerApplicationConfig config = clazz.newInstance();
Set<ServerEndpointConfig> seconfigs = config.getEndpointConfigs(discoveredExtendedEndpoints);
if (seconfigs != null)
{
wasFiltered = true;
deployableExtendedEndpointConfigs.addAll(seconfigs);
}
Set<Class<?>> annotatedClasses = config.getAnnotatedEndpointClasses(discoveredAnnotatedEndpoints);
if (annotatedClasses != null)
{
wasFiltered = true;
deployableAnnotatedEndpoints.addAll(annotatedClasses);
}
}
catch (InstantiationException | IllegalAccessException e)
{
throw new ServletException("Unable to instantiate: " + clazz.getName(),e);
}
}
catch (InstantiationException | IllegalAccessException e)
{
throw new ServletException("Unable to instantiate: " + clazz.getName(),e);
}
}
// Default behavior if nothing filtered
if (!wasFiltered)
{
deployableAnnotatedEndpoints.addAll(discoveredAnnotatedEndpoints);
// Note: it is impossible to determine path of "extends Endpoint" discovered classes
deployableExtendedEndpointConfigs = new HashSet<>();
}
// Default behavior if nothing filtered
if (!wasFiltered)
{
deployableAnnotatedEndpoints.addAll(discoveredAnnotatedEndpoints);
// Note: it is impossible to determine path of "extends Endpoint" discovered classes
deployableExtendedEndpointConfigs = new HashSet<>();
}
if (LOG.isDebugEnabled())
{
LOG.debug("Deploying {} ServerEndpointConfig(s)",deployableExtendedEndpointConfigs.size());
}
// Deploy what should be deployed.
for (ServerEndpointConfig config : deployableExtendedEndpointConfigs)
{
try
if (LOG.isDebugEnabled())
{
jettyContainer.addEndpoint(config);
LOG.debug("Deploying {} ServerEndpointConfig(s)",deployableExtendedEndpointConfigs.size());
}
catch (DeploymentException e)
// Deploy what should be deployed.
for (ServerEndpointConfig config : deployableExtendedEndpointConfigs)
{
throw new ServletException(e);
try
{
jettyContainer.addEndpoint(config);
}
catch (DeploymentException e)
{
throw new ServletException(e);
}
}
}
if (LOG.isDebugEnabled())
{
LOG.debug("Deploying {} @ServerEndpoint(s)",deployableAnnotatedEndpoints.size());
}
for (Class<?> annotatedClass : deployableAnnotatedEndpoints)
{
try
if (LOG.isDebugEnabled())
{
jettyContainer.addEndpoint(annotatedClass);
LOG.debug("Deploying {} @ServerEndpoint(s)",deployableAnnotatedEndpoints.size());
}
catch (DeploymentException e)
for (Class<?> annotatedClass : deployableAnnotatedEndpoints)
{
throw new ServletException(e);
try
{
jettyContainer.addEndpoint(annotatedClass);
}
catch (DeploymentException e)
{
throw new ServletException(e);
}
}
} finally {
Thread.currentThread().setContextClassLoader(old);
}
}

View File

@ -61,6 +61,7 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
private final LogicalConnection connection;
private final SessionListener[] sessionListeners;
private final Executor executor;
private ClassLoader classLoader;
private ExtensionFactory extensionFactory;
private String protocolVersion;
private Map<String, String[]> parameterMap = new HashMap<>();
@ -78,6 +79,7 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
throw new RuntimeException("Request URI cannot be null");
}
this.classLoader = Thread.currentThread().getContextClassLoader();
this.requestURI = requestURI;
this.websocket = websocket;
this.connection = connection;
@ -182,6 +184,11 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
{
return this.connection.getBufferPool();
}
public ClassLoader getClassLoader()
{
return this.getClass().getClassLoader();
}
public LogicalConnection getConnection()
{
@ -393,15 +400,17 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
// already opened
return;
}
ClassLoader old = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(classLoader);
// Upgrade success
connection.getIOState().onConnected();
// Upgrade success
connection.getIOState().onConnected();
// Connect remote
remote = new WebSocketRemoteEndpoint(connection,outgoingHandler,getBatchMode());
// Connect remote
remote = new WebSocketRemoteEndpoint(connection,outgoingHandler,getBatchMode());
try
{
// Open WebSocket
websocket.openSession(this);
@ -425,6 +434,10 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
close(statusCode,t.getMessage());
}
finally
{
Thread.currentThread().setContextClassLoader(old);
}
}
public void setExtensionFactory(ExtensionFactory extensionFactory)
@ -506,4 +519,5 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Inc
builder.append("]");
return builder.toString();
}
}

View File

@ -30,6 +30,7 @@ import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Executor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -70,6 +71,7 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
{
private static final Logger LOG = Log.getLogger(WebSocketServerFactory.class);
private final ClassLoader contextClassloader;
private final Map<Integer, WebSocketHandshake> handshakes = new HashMap<>();
/**
* Have the factory maintain 1 and only 1 scheduler. All connections share this scheduler.
@ -106,6 +108,8 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
addBean(scheduler);
addBean(bufferPool);
this.contextClassloader = Thread.currentThread().getContextClassLoader();
this.registeredSocketClasses = new ArrayList<>();
@ -151,8 +155,10 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
@Override
public boolean acceptWebSocket(WebSocketCreator creator, HttpServletRequest request, HttpServletResponse response) throws IOException
{
ClassLoader old = Thread.currentThread().getContextClassLoader();
try
{
Thread.currentThread().setContextClassLoader(contextClassloader);
ServletUpgradeRequest sockreq = new ServletUpgradeRequest(request);
ServletUpgradeResponse sockresp = new ServletUpgradeResponse(response);
@ -181,6 +187,10 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
catch (URISyntaxException e)
{
throw new IOException("Unable to accept websocket due to mangled URI", e);
}
finally
{
Thread.currentThread().setContextClassLoader(old);
}
}

View File

@ -91,10 +91,7 @@ public class WebSocketUpgradeFilter extends ContainerLifeCycle implements Filter
String pathSpec = "/*";
EnumSet<DispatcherType> dispatcherTypes = EnumSet.of(DispatcherType.REQUEST);
boolean isMatchAfter = false;
String urlPatterns[] =
{
pathSpec
};
String urlPatterns[] = { pathSpec };
FilterRegistration.Dynamic dyn = context.addFilter(name,filter);
dyn.addMappingForUrlPatterns(dispatcherTypes,isMatchAfter,urlPatterns);