Fixes #4408 - fix issues with javax metadata and decoders (#4452)

Remove the metadata cache per endpoint class, to allow deployment of the same class with different EndpointConfig settings.

JavaxServerFrameHandlerFactory now matches for decoders before matching for basic onMessage signatures.

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan 2020-01-13 13:49:10 +11:00 committed by GitHub
parent f4fc78ac66
commit 5bd4cee7c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 428 additions and 326 deletions

View File

@ -46,7 +46,7 @@ public class JavaxWebSocketClientFrameHandlerFactory extends JavaxWebSocketFrame
} }
@Override @Override
public JavaxWebSocketFrameHandlerMetadata createMetadata(Class<?> endpointClass, EndpointConfig endpointConfig) public JavaxWebSocketFrameHandlerMetadata getMetadata(Class<?> endpointClass, EndpointConfig endpointConfig)
{ {
if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass)) if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass))
{ {

View File

@ -22,7 +22,6 @@ import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -93,18 +92,17 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
private MethodHandle errorHandle; private MethodHandle errorHandle;
private JavaxWebSocketFrameHandlerMetadata.MessageMetadata textMetadata; private JavaxWebSocketFrameHandlerMetadata.MessageMetadata textMetadata;
private JavaxWebSocketFrameHandlerMetadata.MessageMetadata binaryMetadata; private JavaxWebSocketFrameHandlerMetadata.MessageMetadata binaryMetadata;
// TODO: need pingHandle ?
private MethodHandle pongHandle; private MethodHandle pongHandle;
private UpgradeRequest upgradeRequest; private UpgradeRequest upgradeRequest;
private UpgradeResponse upgradeResponse; private UpgradeResponse upgradeResponse;
private EndpointConfig endpointConfig; private EndpointConfig endpointConfig;
private final Map<Byte, RegisteredMessageHandler> messageHandlerMap = new HashMap<>();
private MessageSink textSink; private MessageSink textSink;
private MessageSink binarySink; private MessageSink binarySink;
private MessageSink activeMessageSink; private MessageSink activeMessageSink;
private JavaxWebSocketSession session; private JavaxWebSocketSession session;
private Map<Byte, RegisteredMessageHandler> messageHandlerMap;
private CoreSession coreSession; private CoreSession coreSession;
protected byte dataType = OpCode.UNDEFINED; protected byte dataType = OpCode.UNDEFINED;
@ -136,7 +134,6 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
this.pongHandle = pongHandle; this.pongHandle = pongHandle;
this.endpointConfig = endpointConfig; this.endpointConfig = endpointConfig;
this.messageHandlerMap = new HashMap<>();
} }
public Object getEndpoint() public Object getEndpoint()
@ -164,16 +161,12 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
// Rewire EndpointConfig to call CoreSession setters if Jetty specific properties are set. // Rewire EndpointConfig to call CoreSession setters if Jetty specific properties are set.
endpointConfig = getWrappedEndpointConfig(); endpointConfig = getWrappedEndpointConfig();
session = new JavaxWebSocketSession(container, coreSession, this, endpointConfig); session = new JavaxWebSocketSession(container, coreSession, this, endpointConfig);
openHandle = InvokerUtils.bindTo(openHandle, session, endpointConfig); openHandle = InvokerUtils.bindTo(openHandle, session, endpointConfig);
closeHandle = InvokerUtils.bindTo(closeHandle, session); closeHandle = InvokerUtils.bindTo(closeHandle, session);
errorHandle = InvokerUtils.bindTo(errorHandle, session); errorHandle = InvokerUtils.bindTo(errorHandle, session);
JavaxWebSocketFrameHandlerMetadata.MessageMetadata actualTextMetadata = JavaxWebSocketFrameHandlerMetadata.MessageMetadata.copyOf(textMetadata);
JavaxWebSocketFrameHandlerMetadata.MessageMetadata actualBinaryMetadata = JavaxWebSocketFrameHandlerMetadata.MessageMetadata.copyOf(binaryMetadata);
pongHandle = InvokerUtils.bindTo(pongHandle, session); pongHandle = InvokerUtils.bindTo(pongHandle, session);
JavaxWebSocketFrameHandlerMetadata.MessageMetadata actualTextMetadata = JavaxWebSocketFrameHandlerMetadata.MessageMetadata.copyOf(textMetadata);
if (actualTextMetadata != null) if (actualTextMetadata != null)
{ {
if (actualTextMetadata.isMaxMessageSizeSet()) if (actualTextMetadata.isMaxMessageSizeSet())
@ -182,10 +175,10 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
actualTextMetadata.handle = InvokerUtils.bindTo(actualTextMetadata.handle, endpointInstance, endpointConfig, session); actualTextMetadata.handle = InvokerUtils.bindTo(actualTextMetadata.handle, endpointInstance, endpointConfig, session);
actualTextMetadata.handle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(actualTextMetadata.handle, session); actualTextMetadata.handle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(actualTextMetadata.handle, session);
textSink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualTextMetadata); textSink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualTextMetadata);
textMetadata = actualTextMetadata; textMetadata = actualTextMetadata;
} }
JavaxWebSocketFrameHandlerMetadata.MessageMetadata actualBinaryMetadata = JavaxWebSocketFrameHandlerMetadata.MessageMetadata.copyOf(binaryMetadata);
if (actualBinaryMetadata != null) if (actualBinaryMetadata != null)
{ {
if (actualBinaryMetadata.isMaxMessageSizeSet()) if (actualBinaryMetadata.isMaxMessageSizeSet())
@ -194,7 +187,6 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
actualBinaryMetadata.handle = InvokerUtils.bindTo(actualBinaryMetadata.handle, endpointInstance, endpointConfig, session); actualBinaryMetadata.handle = InvokerUtils.bindTo(actualBinaryMetadata.handle, endpointInstance, endpointConfig, session);
actualBinaryMetadata.handle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(actualBinaryMetadata.handle, session); actualBinaryMetadata.handle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(actualBinaryMetadata.handle, session);
binarySink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualBinaryMetadata); binarySink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualBinaryMetadata);
binaryMetadata = actualBinaryMetadata; binaryMetadata = actualBinaryMetadata;
} }
@ -327,14 +319,9 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
public Set<MessageHandler> getMessageHandlers() public Set<MessageHandler> getMessageHandlers()
{ {
if (messageHandlerMap.isEmpty()) return messageHandlerMap.values().stream()
{ .map(RegisteredMessageHandler::getMessageHandler)
return Collections.emptySet(); .collect(Collectors.toUnmodifiableSet());
}
return Collections.unmodifiableSet(messageHandlerMap.values().stream()
.map((rh) -> rh.getMessageHandler())
.collect(Collectors.toSet()));
} }
public Map<Byte, RegisteredMessageHandler> getMessageHandlerMap() public Map<Byte, RegisteredMessageHandler> getMessageHandlerMap()
@ -368,7 +355,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
// TODO: move methodhandle lookup to container? // TODO: move methodhandle lookup to container?
MethodHandles.Lookup lookup = MethodHandles.publicLookup(); MethodHandles.Lookup lookup = MethodHandles.publicLookup();
MethodHandle partialMessageHandler = lookup MethodHandle partialMessageHandler = lookup
.findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(Void.TYPE, Object.class, Boolean.TYPE)); .findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(void.class, Object.class, boolean.class));
partialMessageHandler = partialMessageHandler.bindTo(handler); partialMessageHandler = partialMessageHandler.bindTo(handler);
// MessageHandler.Partial has no decoder support! // MessageHandler.Partial has no decoder support!
@ -423,9 +410,9 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
{ {
try try
{ {
// TODO: move methodhandle lookup to container? // TODO: move MethodHandle lookup to container?
MethodHandles.Lookup lookup = MethodHandles.publicLookup(); MethodHandles.Lookup lookup = MethodHandles.publicLookup();
MethodHandle wholeMsgMethodHandle = lookup.findVirtual(MessageHandler.Whole.class, "onMessage", MethodType.methodType(Void.TYPE, Object.class)); MethodHandle wholeMsgMethodHandle = lookup.findVirtual(MessageHandler.Whole.class, "onMessage", MethodType.methodType(void.class, Object.class));
wholeMsgMethodHandle = wholeMsgMethodHandle.bindTo(handler); wholeMsgMethodHandle = wholeMsgMethodHandle.bindTo(handler);
if (PongMessage.class.isAssignableFrom(clazz)) if (PongMessage.class.isAssignableFrom(clazz))

View File

@ -31,7 +31,7 @@ import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function;
import javax.websocket.CloseReason; import javax.websocket.CloseReason;
import javax.websocket.Decoder; import javax.websocket.Decoder;
import javax.websocket.EndpointConfig; import javax.websocket.EndpointConfig;
@ -68,12 +68,61 @@ public abstract class JavaxWebSocketFrameHandlerFactory
{ {
private static final MethodHandle FILTER_RETURN_TYPE_METHOD; private static final MethodHandle FILTER_RETURN_TYPE_METHOD;
// The different kind of @OnMessage method parameter signatures expected.
private static final InvokerUtils.Arg[] textCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(String.class).required()
};
private static final InvokerUtils.Arg[] textPartialCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(String.class).required(),
new InvokerUtils.Arg(boolean.class).required()
};
private static final InvokerUtils.Arg[] binaryBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required()
};
private static final InvokerUtils.Arg[] binaryPartialBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required(),
new InvokerUtils.Arg(boolean.class).required()
};
private static final InvokerUtils.Arg[] binaryArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required()
};
private static final InvokerUtils.Arg[] binaryPartialArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required(),
new InvokerUtils.Arg(boolean.class).required()
};
private static final InvokerUtils.Arg[] inputStreamCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(InputStream.class).required()
};
private static final InvokerUtils.Arg[] readerCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(Reader.class).required()
};
private static final InvokerUtils.Arg[] pongCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(PongMessage.class).required()
};
static static
{ {
try try
{ {
FILTER_RETURN_TYPE_METHOD = MethodHandles.lookup() FILTER_RETURN_TYPE_METHOD = MethodHandles.lookup()
.findVirtual(JavaxWebSocketSession.class, "filterReturnType", MethodType.methodType(Void.TYPE, Object.class)); .findVirtual(JavaxWebSocketSession.class, "filterReturnType", MethodType.methodType(void.class, Object.class));
} }
catch (Throwable e) catch (Throwable e)
{ {
@ -83,7 +132,6 @@ public abstract class JavaxWebSocketFrameHandlerFactory
protected final JavaxWebSocketContainer container; protected final JavaxWebSocketContainer container;
protected final InvokerUtils.ParamIdentifier paramIdentifier; protected final InvokerUtils.ParamIdentifier paramIdentifier;
private Map<Class<?>, JavaxWebSocketFrameHandlerMetadata> metadataMap = new ConcurrentHashMap<>();
public JavaxWebSocketFrameHandlerFactory(JavaxWebSocketContainer container, InvokerUtils.ParamIdentifier paramIdentifier) public JavaxWebSocketFrameHandlerFactory(JavaxWebSocketContainer container, InvokerUtils.ParamIdentifier paramIdentifier)
{ {
@ -91,23 +139,10 @@ public abstract class JavaxWebSocketFrameHandlerFactory
this.paramIdentifier = paramIdentifier == null ? InvokerUtils.PARAM_IDENTITY : paramIdentifier; this.paramIdentifier = paramIdentifier == null ? InvokerUtils.PARAM_IDENTITY : paramIdentifier;
} }
public JavaxWebSocketFrameHandlerMetadata getMetadata(Class<?> endpointClass, EndpointConfig endpointConfig) public abstract JavaxWebSocketFrameHandlerMetadata getMetadata(Class<?> endpointClass, EndpointConfig endpointConfig);
{
JavaxWebSocketFrameHandlerMetadata metadata = metadataMap.get(endpointClass);
if (metadata == null)
{
metadata = createMetadata(endpointClass, endpointConfig);
metadataMap.put(endpointClass, metadata);
}
return metadata;
}
public abstract EndpointConfig newDefaultEndpointConfig(Class<?> endpointClass, String path); public abstract EndpointConfig newDefaultEndpointConfig(Class<?> endpointClass, String path);
public abstract JavaxWebSocketFrameHandlerMetadata createMetadata(Class<?> endpointClass, EndpointConfig endpointConfig);
public JavaxWebSocketFrameHandler newJavaxWebSocketFrameHandler(Object endpointInstance, UpgradeRequest upgradeRequest) public JavaxWebSocketFrameHandler newJavaxWebSocketFrameHandler(Object endpointInstance, UpgradeRequest upgradeRequest)
{ {
Object endpoint; Object endpoint;
@ -161,15 +196,13 @@ public abstract class JavaxWebSocketFrameHandlerFactory
errorHandle = InvokerUtils.bindTo(errorHandle, endpoint); errorHandle = InvokerUtils.bindTo(errorHandle, endpoint);
pongHandle = InvokerUtils.bindTo(pongHandle, endpoint); pongHandle = InvokerUtils.bindTo(pongHandle, endpoint);
JavaxWebSocketFrameHandler frameHandler = new JavaxWebSocketFrameHandler( return new JavaxWebSocketFrameHandler(
container, container,
endpoint, endpoint,
openHandle, closeHandle, errorHandle, openHandle, closeHandle, errorHandle,
textMetadata, binaryMetadata, textMetadata, binaryMetadata,
pongHandle, pongHandle,
config); config);
return frameHandler;
} }
/** /**
@ -318,7 +351,7 @@ public abstract class JavaxWebSocketFrameHandlerFactory
} }
} }
public static MethodHandle wrapNonVoidReturnType(MethodHandle handle, JavaxWebSocketSession session) throws NoSuchMethodException, IllegalAccessException public static MethodHandle wrapNonVoidReturnType(MethodHandle handle, JavaxWebSocketSession session)
{ {
if (handle == null) if (handle == null)
return null; return null;
@ -400,6 +433,7 @@ public abstract class JavaxWebSocketFrameHandlerFactory
.mutatedInvoker(endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, CLOSE_REASON); .mutatedInvoker(endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, CLOSE_REASON);
metadata.setCloseHandler(methodHandle, onmethod); metadata.setCloseHandler(methodHandle, onmethod);
} }
// OnError [0..1] // OnError [0..1]
onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnError.class); onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnError.class);
if (onmethod != null) if (onmethod != null)
@ -416,293 +450,32 @@ public abstract class JavaxWebSocketFrameHandlerFactory
Method[] onMessages = ReflectUtils.findAnnotatedMethods(endpointClass, OnMessage.class); Method[] onMessages = ReflectUtils.findAnnotatedMethods(endpointClass, OnMessage.class);
if (onMessages != null && onMessages.length > 0) if (onMessages != null && onMessages.length > 0)
{ {
// The different kind of @OnMessage method parameter signatures expected
InvokerUtils.Arg[] textCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(String.class).required()
};
InvokerUtils.Arg[] textPartialCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(String.class).required(),
new InvokerUtils.Arg(boolean.class).required()
};
InvokerUtils.Arg[] binaryBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required()
};
InvokerUtils.Arg[] binaryPartialBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required(),
new InvokerUtils.Arg(boolean.class).required()
};
InvokerUtils.Arg[] binaryArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required()
};
InvokerUtils.Arg[] binaryPartialArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required(),
new InvokerUtils.Arg(boolean.class).required()
};
InvokerUtils.Arg[] inputStreamCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(InputStream.class).required()
};
InvokerUtils.Arg[] readerCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(Reader.class).required()
};
InvokerUtils.Arg[] pongCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(PongMessage.class).required()
};
List<DecodedArgs> decodedTextCallingArgs = new ArrayList<>();
List<DecodedArgs> decodedTextStreamCallingArgs = new ArrayList<>();
List<DecodedArgs> decodedBinaryCallingArgs = new ArrayList<>();
List<DecodedArgs> decodedBinaryStreamCallingArgs = new ArrayList<>();
for (AvailableDecoders.RegisteredDecoder decoder : metadata.getAvailableDecoders())
{
if (decoder.implementsInterface(Decoder.Text.class))
{
decodedTextCallingArgs.add(
new DecodedArgs(decoder,
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(decoder.objectType).required()
));
}
if (decoder.implementsInterface(Decoder.TextStream.class))
{
decodedTextStreamCallingArgs.add(
new DecodedArgs(decoder,
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(decoder.objectType).required()
));
}
if (decoder.implementsInterface(Decoder.Binary.class))
{
decodedBinaryCallingArgs.add(
new DecodedArgs(decoder,
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(decoder.objectType).required()
));
}
if (decoder.implementsInterface(Decoder.BinaryStream.class))
{
decodedBinaryStreamCallingArgs.add(
new DecodedArgs(decoder,
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(decoder.objectType).required()
));
}
}
onmessageloop:
for (Method onMsg : onMessages) for (Method onMsg : onMessages)
{ {
assertSignatureValid(endpointClass, onMsg, OnMessage.class); assertSignatureValid(endpointClass, onMsg, OnMessage.class);
OnMessage onMessageAnno = onMsg.getAnnotation(OnMessage.class);
MessageMetadata msgMetadata = new MessageMetadata(); MessageMetadata msgMetadata = new MessageMetadata();
OnMessage onMessageAnno = onMsg.getAnnotation(OnMessage.class);
if (onMessageAnno.maxMessageSize() > Integer.MAX_VALUE) if (onMessageAnno.maxMessageSize() > Integer.MAX_VALUE)
{ {
throw new InvalidWebSocketException( throw new InvalidWebSocketException(String.format("Value too large: %s#%s - @OnMessage.maxMessageSize=%,d > Integer.MAX_VALUE",
String.format("Value too large: %s#%s - @OnMessage.maxMessageSize=%,d > Integer.MAX_VALUE",
endpointClass.getName(), onMsg.getName(), onMessageAnno.maxMessageSize())); endpointClass.getName(), onMsg.getName(), onMessageAnno.maxMessageSize()));
} }
msgMetadata.maxMessageSize = (int)onMessageAnno.maxMessageSize(); msgMetadata.maxMessageSize = (int)onMessageAnno.maxMessageSize();
MethodHandle methodHandle = InvokerUtils // Function to search for matching MethodHandle for the endpointClass given a signature.
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), textCallingArgs); Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle = (signature) ->
if (methodHandle != null) InvokerUtils.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), signature);
{
// Whole Text Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = StringMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setTextMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
methodHandle = InvokerUtils // Try to match from available decoders.
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), textPartialCallingArgs); if (matchDecoders(onMsg, metadata, msgMetadata, getMethodHandle))
if (methodHandle != null) continue;
{
// Partial Text Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = PartialStringMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setTextMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
methodHandle = InvokerUtils // No decoders matched try basic signatures to call onMessage directly.
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), binaryBufferCallingArgs); if (matchOnMessage(onMsg, metadata, msgMetadata, getMethodHandle))
if (methodHandle != null) continue;
{
// Whole ByteBuffer Binary Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = ByteBufferMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
methodHandle = InvokerUtils // Not a valid @OnMessage declaration signature.
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), binaryPartialBufferCallingArgs);
if (methodHandle != null)
{
// Partial ByteBuffer Binary Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = PartialByteBufferMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), binaryArrayCallingArgs);
if (methodHandle != null)
{
// Whole byte[] Binary Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = ByteArrayMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), binaryPartialArrayCallingArgs);
if (methodHandle != null)
{
// Partial byte[] Binary Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = PartialByteArrayMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), inputStreamCallingArgs);
if (methodHandle != null)
{
// InputStream Binary Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = InputStreamMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), readerCallingArgs);
if (methodHandle != null)
{
// Reader Text Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = ReaderMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setTextMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
// == Decoders ==
// Decoder.Text
for (DecodedArgs decodedArgs : decodedTextCallingArgs)
{
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), decodedArgs.args);
if (methodHandle != null)
{
// Decoded Text Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = DecodedTextMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setTextMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
}
// Decoder.Binary
for (DecodedArgs decodedArgs : decodedBinaryCallingArgs)
{
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), decodedArgs.args);
if (methodHandle != null)
{
// Decoded Binary Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = DecodedBinaryMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setBinaryMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
}
// Decoder.TextStream
for (DecodedArgs decodedArgs : decodedTextStreamCallingArgs)
{
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), decodedArgs.args);
if (methodHandle != null)
{
// Decoded Text Stream
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = DecodedTextStreamMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setTextMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
}
// Decoder.BinaryStream
for (DecodedArgs decodedArgs : decodedBinaryStreamCallingArgs)
{
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), decodedArgs.args);
if (methodHandle != null)
{
// Decoded Binary Stream
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
msgMetadata.sinkClass = DecodedBinaryStreamMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setBinaryMetadata(msgMetadata, onMsg);
continue onmessageloop;
}
}
// == Pong ==
methodHandle = InvokerUtils
.optionalMutatedInvoker(endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), pongCallingArgs);
if (methodHandle != null)
{
// Pong Message
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
metadata.setPongHandle(methodHandle, onMsg);
continue onmessageloop;
}
// Not a valid @OnMessage declaration signature
throw InvalidSignatureException.build(endpointClass, OnMessage.class, onMsg); throw InvalidSignatureException.build(endpointClass, OnMessage.class, onMsg);
} }
} }
@ -710,6 +483,184 @@ public abstract class JavaxWebSocketFrameHandlerFactory
return metadata; return metadata;
} }
private boolean matchOnMessage(Method onMsg, JavaxWebSocketFrameHandlerMetadata metadata, MessageMetadata msgMetadata,
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle)
{
// Whole Text Message.
MethodHandle methodHandle = getMethodHandle.apply(textCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = StringMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
// Partial Text Message.
methodHandle = getMethodHandle.apply(textPartialCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = PartialStringMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
// Whole ByteBuffer Binary Message.
methodHandle = getMethodHandle.apply(binaryBufferCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = ByteBufferMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Partial ByteBuffer Binary Message.
methodHandle = getMethodHandle.apply(binaryPartialBufferCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = PartialByteBufferMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Whole byte[] Binary Message.
methodHandle = getMethodHandle.apply(binaryArrayCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = ByteArrayMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Partial byte[] Binary Message.
methodHandle = getMethodHandle.apply(binaryPartialArrayCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = PartialByteArrayMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// InputStream Binary Message.
methodHandle = getMethodHandle.apply(inputStreamCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = InputStreamMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Reader Text Message.
methodHandle = getMethodHandle.apply(readerCallingArgs);
if (methodHandle != null)
{
msgMetadata.sinkClass = ReaderMessageSink.class;
msgMetadata.handle = methodHandle;
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
// Pong Message.
MethodHandle pongHandle = getMethodHandle.apply(pongCallingArgs);
if (pongHandle != null)
{
metadata.setPongHandle(pongHandle, onMsg);
return true;
}
return false;
}
private boolean matchDecoders(Method onMsg, JavaxWebSocketFrameHandlerMetadata metadata, MessageMetadata msgMetadata,
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle)
{
// TODO: we should be able to get this information directly from the AvailableDecoders in the metadata.
List<DecodedArgs> decodedTextCallingArgs = new ArrayList<>();
List<DecodedArgs> decodedTextStreamCallingArgs = new ArrayList<>();
List<DecodedArgs> decodedBinaryCallingArgs = new ArrayList<>();
List<DecodedArgs> decodedBinaryStreamCallingArgs = new ArrayList<>();
for (AvailableDecoders.RegisteredDecoder decoder : metadata.getAvailableDecoders())
{
InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(decoder.objectType).required()};
DecodedArgs decodedArgs = new DecodedArgs(decoder, args);
if (decoder.implementsInterface(Decoder.Text.class))
decodedTextCallingArgs.add(decodedArgs);
if (decoder.implementsInterface(Decoder.TextStream.class))
decodedTextStreamCallingArgs.add(decodedArgs);
if (decoder.implementsInterface(Decoder.Binary.class))
decodedBinaryCallingArgs.add(decodedArgs);
if (decoder.implementsInterface(Decoder.BinaryStream.class))
decodedBinaryStreamCallingArgs.add(decodedArgs);
}
MethodHandle methodHandle;
// Decoder.Text
for (DecodedArgs decodedArgs : decodedTextCallingArgs)
{
methodHandle = getMethodHandle.apply(decodedArgs.args);
if (methodHandle != null)
{
msgMetadata.sinkClass = DecodedTextMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
}
// Decoder.Binary
for (DecodedArgs decodedArgs : decodedBinaryCallingArgs)
{
methodHandle = getMethodHandle.apply(decodedArgs.args);
if (methodHandle != null)
{
msgMetadata.sinkClass = DecodedBinaryMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
}
// Try to match Text Stream decoders.
for (DecodedArgs decodedArgs : decodedTextStreamCallingArgs)
{
methodHandle = getMethodHandle.apply(decodedArgs.args);
if (methodHandle != null)
{
msgMetadata.sinkClass = DecodedTextStreamMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
}
// Decoder.BinaryStream
for (DecodedArgs decodedArgs : decodedBinaryStreamCallingArgs)
{
methodHandle = getMethodHandle.apply(decodedArgs.args);
if (methodHandle != null)
{
msgMetadata.sinkClass = DecodedBinaryStreamMessageSink.class;
msgMetadata.handle = methodHandle;
msgMetadata.registeredDecoder = decodedArgs.registeredDecoder;
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
}
return false;
}
private void assertSignatureValid(Class<?> endpointClass, Method method, Class<? extends Annotation> annotationClass) private void assertSignatureValid(Class<?> endpointClass, Method method, Class<? extends Annotation> annotationClass)
{ {
// Test modifiers // Test modifiers

View File

@ -39,7 +39,7 @@ public class DummyFrameHandlerFactory extends JavaxWebSocketFrameHandlerFactory
} }
@Override @Override
public JavaxWebSocketFrameHandlerMetadata createMetadata(Class<?> endpointClass, EndpointConfig endpointConfig) public JavaxWebSocketFrameHandlerMetadata getMetadata(Class<?> endpointClass, EndpointConfig endpointConfig)
{ {
if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass)) if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass))
{ {

View File

@ -39,13 +39,7 @@ public class JavaxWebSocketServerFrameHandlerFactory extends JavaxWebSocketClien
} }
@Override @Override
public EndpointConfig newDefaultEndpointConfig(Class<?> endpointClass, String path) public JavaxWebSocketFrameHandlerMetadata getMetadata(Class<?> endpointClass, EndpointConfig endpointConfig)
{
return new BasicServerEndpointConfig(endpointClass, path);
}
@Override
public JavaxWebSocketFrameHandlerMetadata createMetadata(Class<?> endpointClass, EndpointConfig endpointConfig)
{ {
if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass)) if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass))
{ {
@ -55,7 +49,7 @@ public class JavaxWebSocketServerFrameHandlerFactory extends JavaxWebSocketClien
ServerEndpoint anno = endpointClass.getAnnotation(ServerEndpoint.class); ServerEndpoint anno = endpointClass.getAnnotation(ServerEndpoint.class);
if (anno == null) if (anno == null)
{ {
return super.createMetadata(endpointClass, endpointConfig); return super.getMetadata(endpointClass, endpointConfig);
} }
UriTemplatePathSpec templatePathSpec = new UriTemplatePathSpec(anno.value()); UriTemplatePathSpec templatePathSpec = new UriTemplatePathSpec(anno.value());

View File

@ -0,0 +1,170 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.javax.tests.server;
import java.net.URI;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import javax.websocket.ContainerProvider;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.websocket.javax.common.decoders.StringDecoder;
import org.eclipse.jetty.websocket.javax.server.config.JavaxWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.javax.tests.EventSocket;
import org.eclipse.jetty.websocket.javax.tests.WSEndpointTracker;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* Example of an annotated echo server discovered via annotation scanning.
*/
public class ServerDecoderTest
{
private static CompletableFuture<EventSocket> annotatedServerSocket = new CompletableFuture<>();
private static CompletableFuture<WSEndpointTracker> configuredServerSocket = new CompletableFuture<>();
private Server server;
private URI serverURI;
public static class EqualsAppendDecoder extends StringDecoder
{
@Override
public String decode(String s)
{
return s + "=";
}
}
public static class PlusAppendDecoder extends StringDecoder
{
@Override
public String decode(String s)
{
return s + "+";
}
}
@ServerEndpoint(value = "/annotated", decoders = {EqualsAppendDecoder.class})
public static class AnnotatedEndpoint extends EventSocket
{
@Override
public void onOpen(Session session, EndpointConfig config)
{
super.onOpen(session, config);
annotatedServerSocket.complete(this);
}
}
public static class ConfiguredEndpoint extends WSEndpointTracker implements MessageHandler.Whole<String>
{
@Override
public void onOpen(Session session, EndpointConfig config)
{
super.onOpen(session, config);
session.addMessageHandler(this);
configuredServerSocket.complete(this);
}
@Override
public void onMessage(String message)
{
super.onWsText(message);
}
}
@BeforeEach
public void startServer() throws Exception
{
server = new Server();
ServerConnector serverConnector = new ServerConnector(server);
server.addConnector(serverConnector);
ServletContextHandler servletContextHandler = new ServletContextHandler(null, "/");
server.setHandler(servletContextHandler);
JavaxWebSocketServletContainerInitializer.configure(servletContextHandler, ((servletContext, serverContainer) ->
{
serverContainer.addEndpoint(AnnotatedEndpoint.class);
ServerEndpointConfig config = ServerEndpointConfig.Builder.create(ConfiguredEndpoint.class, "/configured")
.decoders(Collections.singletonList(PlusAppendDecoder.class))
.build();
serverContainer.addEndpoint(config);
}));
server.start();
serverURI = new URI("ws://localhost:" + serverConnector.getLocalPort());
}
@AfterEach
public void stopServer() throws Exception
{
if (server != null)
server.stop();
}
@Test
public void testAnnotatedDecoder() throws Exception
{
WebSocketContainer client = ContainerProvider.getWebSocketContainer();
EventSocket clientSocket = new EventSocket();
Session session = client.connectToServer(clientSocket, serverURI.resolve("/annotated"));
session.getBasicRemote().sendText("hello world");
EventSocket serverSocket = annotatedServerSocket.get(5, TimeUnit.SECONDS);
assertTrue(serverSocket.openLatch.await(5, TimeUnit.SECONDS));
String msg = serverSocket.messageQueue.poll(5, TimeUnit.SECONDS);
assertThat(msg, is("hello world="));
clientSocket.session.close();
clientSocket.closeLatch.await(5, TimeUnit.SECONDS);
serverSocket.closeLatch.await(5, TimeUnit.SECONDS);
}
@Test
public void testConfiguredDecoder() throws Exception
{
WebSocketContainer client = ContainerProvider.getWebSocketContainer();
EventSocket clientSocket = new EventSocket();
Session session = client.connectToServer(clientSocket, serverURI.resolve("/configured"));
session.getBasicRemote().sendText("hello world");
WSEndpointTracker serverSocket = configuredServerSocket.get(5, TimeUnit.SECONDS);
assertTrue(serverSocket.openLatch.await(5, TimeUnit.SECONDS));
String msg = serverSocket.messageQueue.poll(5, TimeUnit.SECONDS);
assertThat(msg, is("hello world+"));
clientSocket.session.close();
clientSocket.closeLatch.await(5, TimeUnit.SECONDS);
serverSocket.closeLatch.await(5, TimeUnit.SECONDS);
}
}