Issue #3428 - fix decoder list matching to get test working

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-05-21 08:52:01 +10:00
parent 4b19c19815
commit 8e554c7d13
6 changed files with 327 additions and 472 deletions

View File

@ -18,8 +18,6 @@
package org.eclipse.jetty.websocket.javax.common; package org.eclipse.jetty.websocket.javax.common;
import java.io.InputStream;
import java.io.Reader;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import javax.websocket.PongMessage; import javax.websocket.PongMessage;
import javax.websocket.Session; import javax.websocket.Session;
@ -29,10 +27,10 @@ import org.eclipse.jetty.websocket.util.InvokerUtils;
// The different kind of @OnMessage method parameter signatures expected. // The different kind of @OnMessage method parameter signatures expected.
public class JavaxWebSocketCallingArgs public class JavaxWebSocketCallingArgs
{ {
static final InvokerUtils.Arg[] textCallingArgs = new InvokerUtils.Arg[]{ static InvokerUtils.Arg[] getArgsFor(Class<?> objectType)
new InvokerUtils.Arg(Session.class), {
new InvokerUtils.Arg(String.class).required() return new InvokerUtils.Arg[]{new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(objectType).required()};
}; }
static final InvokerUtils.Arg[] textPartialCallingArgs = new InvokerUtils.Arg[]{ static final InvokerUtils.Arg[] textPartialCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(Session.class),
@ -40,38 +38,18 @@ public class JavaxWebSocketCallingArgs
new InvokerUtils.Arg(boolean.class).required() new InvokerUtils.Arg(boolean.class).required()
}; };
static final InvokerUtils.Arg[] binaryBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required()
};
static final InvokerUtils.Arg[] binaryPartialBufferCallingArgs = new InvokerUtils.Arg[]{ static final InvokerUtils.Arg[] binaryPartialBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required(), new InvokerUtils.Arg(ByteBuffer.class).required(),
new InvokerUtils.Arg(boolean.class).required() new InvokerUtils.Arg(boolean.class).required()
}; };
static final InvokerUtils.Arg[] binaryArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required()
};
static final InvokerUtils.Arg[] binaryPartialArrayCallingArgs = new InvokerUtils.Arg[]{ static final InvokerUtils.Arg[] binaryPartialArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required(), new InvokerUtils.Arg(byte[].class).required(),
new InvokerUtils.Arg(boolean.class).required() new InvokerUtils.Arg(boolean.class).required()
}; };
static final InvokerUtils.Arg[] inputStreamCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(InputStream.class).required()
};
static final InvokerUtils.Arg[] readerCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(Reader.class).required()
};
static final InvokerUtils.Arg[] pongCallingArgs = new InvokerUtils.Arg[]{ static final InvokerUtils.Arg[] pongCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(PongMessage.class).required() new InvokerUtils.Arg(PongMessage.class).required()

View File

@ -460,12 +460,12 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata(); JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(methodHandle); metadata.setMethodHandle(methodHandle);
metadata.setRegisteredDecoder(registeredDecoder);
if (registeredDecoder.implementsInterface(Decoder.Binary.class)) if (registeredDecoder.implementsInterface(Decoder.Binary.class))
{ {
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName()); assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
List<RegisteredDecoder> binaryDecoders = availableDecoders.getBinaryDecoders(clazz); List<RegisteredDecoder> binaryDecoders = availableDecoders.getBinaryDecoders(clazz);
metadata.setRegisteredDecoders(binaryDecoders);
MessageSink messageSink = new DecodedBinaryMessageSink<T>(coreSession, methodHandle, binaryDecoders); MessageSink messageSink = new DecodedBinaryMessageSink<T>(coreSession, methodHandle, binaryDecoders);
metadata.setSinkClass(messageSink.getClass()); metadata.setSinkClass(messageSink.getClass());
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink); this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
@ -475,6 +475,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
{ {
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName()); assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
List<RegisteredDecoder> binaryStreamDecoders = availableDecoders.getBinaryStreamDecoders(clazz); List<RegisteredDecoder> binaryStreamDecoders = availableDecoders.getBinaryStreamDecoders(clazz);
metadata.setRegisteredDecoders(binaryStreamDecoders);
MessageSink messageSink = new DecodedBinaryStreamMessageSink<T>(coreSession, methodHandle, binaryStreamDecoders); MessageSink messageSink = new DecodedBinaryStreamMessageSink<T>(coreSession, methodHandle, binaryStreamDecoders);
metadata.setSinkClass(messageSink.getClass()); metadata.setSinkClass(messageSink.getClass());
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink); this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
@ -484,6 +485,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
{ {
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName()); assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
List<RegisteredDecoder> textDecoders = availableDecoders.getTextDecoders(clazz); List<RegisteredDecoder> textDecoders = availableDecoders.getTextDecoders(clazz);
metadata.setRegisteredDecoders(textDecoders);
MessageSink messageSink = new DecodedTextMessageSink<T>(coreSession, methodHandle, textDecoders); MessageSink messageSink = new DecodedTextMessageSink<T>(coreSession, methodHandle, textDecoders);
metadata.setSinkClass(messageSink.getClass()); metadata.setSinkClass(messageSink.getClass());
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink); this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
@ -493,6 +495,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
{ {
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName()); assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
List<RegisteredDecoder> textStreamDecoders = availableDecoders.getTextStreamDecoders(clazz); List<RegisteredDecoder> textStreamDecoders = availableDecoders.getTextStreamDecoders(clazz);
metadata.setRegisteredDecoders(textStreamDecoders);
MessageSink messageSink = new DecodedTextStreamMessageSink<T>(coreSession, methodHandle, textStreamDecoders); MessageSink messageSink = new DecodedTextStreamMessageSink<T>(coreSession, methodHandle, textStreamDecoders);
metadata.setSinkClass(messageSink.getClass()); metadata.setSinkClass(messageSink.getClass());
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink); this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
@ -618,7 +621,6 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
// Use JSR356 PongMessage interface // Use JSR356 PongMessage interface
JavaxWebSocketPongMessage pongMessage = new JavaxWebSocketPongMessage(payload); JavaxWebSocketPongMessage pongMessage = new JavaxWebSocketPongMessage(payload);
pongHandle.invoke(pongMessage); pongHandle.invoke(pongMessage);
} }
catch (Throwable cause) catch (Throwable cause)

View File

@ -29,9 +29,9 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Stream;
import javax.websocket.CloseReason; import javax.websocket.CloseReason;
import javax.websocket.Decoder; import javax.websocket.Decoder;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig; import javax.websocket.EndpointConfig;
import javax.websocket.OnClose; import javax.websocket.OnClose;
import javax.websocket.OnError; import javax.websocket.OnError;
@ -51,17 +51,13 @@ import org.eclipse.jetty.websocket.util.InvalidSignatureException;
import org.eclipse.jetty.websocket.util.InvalidWebSocketException; import org.eclipse.jetty.websocket.util.InvalidWebSocketException;
import org.eclipse.jetty.websocket.util.InvokerUtils; import org.eclipse.jetty.websocket.util.InvokerUtils;
import org.eclipse.jetty.websocket.util.ReflectUtils; import org.eclipse.jetty.websocket.util.ReflectUtils;
import org.eclipse.jetty.websocket.util.messages.ByteArrayMessageSink;
import org.eclipse.jetty.websocket.util.messages.ByteBufferMessageSink;
import org.eclipse.jetty.websocket.util.messages.InputStreamMessageSink;
import org.eclipse.jetty.websocket.util.messages.MessageSink; import org.eclipse.jetty.websocket.util.messages.MessageSink;
import org.eclipse.jetty.websocket.util.messages.PartialByteArrayMessageSink; import org.eclipse.jetty.websocket.util.messages.PartialByteArrayMessageSink;
import org.eclipse.jetty.websocket.util.messages.PartialByteBufferMessageSink; import org.eclipse.jetty.websocket.util.messages.PartialByteBufferMessageSink;
import org.eclipse.jetty.websocket.util.messages.PartialStringMessageSink; import org.eclipse.jetty.websocket.util.messages.PartialStringMessageSink;
import org.eclipse.jetty.websocket.util.messages.ReaderMessageSink;
import org.eclipse.jetty.websocket.util.messages.StringMessageSink;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static org.eclipse.jetty.websocket.javax.common.JavaxWebSocketCallingArgs.getArgsFor;
public abstract class JavaxWebSocketFrameHandlerFactory public abstract class JavaxWebSocketFrameHandlerFactory
{ {
@ -154,6 +150,321 @@ public abstract class JavaxWebSocketFrameHandlerFactory
config); config);
} }
public static MessageSink createMessageSink(JavaxWebSocketSession session, JavaxWebSocketMessageMetadata msgMetadata)
{
if (msgMetadata == null)
return null;
try
{
MethodHandles.Lookup lookup = getServerMethodHandleLookup();
if (AbstractDecodedMessageSink.class.isAssignableFrom(msgMetadata.getSinkClass()))
{
MethodHandle ctorHandle = lookup.findConstructor(msgMetadata.getSinkClass(),
MethodType.methodType(void.class, CoreSession.class, MethodHandle.class, List.class));
List<RegisteredDecoder> registeredDecoders = msgMetadata.getRegisteredDecoders();
return (MessageSink)ctorHandle.invoke(session.getCoreSession(), msgMetadata.getMethodHandle(), registeredDecoders);
}
else
{
MethodHandle ctorHandle = lookup.findConstructor(msgMetadata.getSinkClass(),
MethodType.methodType(void.class, CoreSession.class, MethodHandle.class));
return (MessageSink)ctorHandle.invoke(session.getCoreSession(), msgMetadata.getMethodHandle());
}
}
catch (NoSuchMethodException e)
{
throw new RuntimeException("Missing expected MessageSink constructor found at: " + msgMetadata.getSinkClass().getName(), e);
}
catch (IllegalAccessException | InstantiationException | InvocationTargetException e)
{
throw new RuntimeException("Unable to create MessageSink: " + msgMetadata.getSinkClass().getName(), e);
}
catch (RuntimeException e)
{
throw e;
}
catch (Throwable t)
{
throw new RuntimeException(t);
}
}
public static MethodHandle wrapNonVoidReturnType(MethodHandle handle, JavaxWebSocketSession session)
{
if (handle == null)
return null;
if (handle.type().returnType() == Void.TYPE)
return handle;
// Technique from https://stackoverflow.com/questions/48505787/methodhandle-with-general-non-void-return-filter
// Change the return type of the to be Object so it will match exact with JavaxWebSocketSession.filterReturnType(Object)
handle = handle.asType(handle.type().changeReturnType(Object.class));
// Filter the method return type to a call to JavaxWebSocketSession.filterReturnType() bound to this session
handle = MethodHandles.filterReturnValue(handle, FILTER_RETURN_TYPE_METHOD.bindTo(session));
return handle;
}
private MethodHandle toMethodHandle(MethodHandles.Lookup lookup, Method method)
{
try
{
return lookup.unreflect(method);
}
catch (IllegalAccessException e)
{
throw new RuntimeException("Unable to access method " + method, e);
}
}
protected JavaxWebSocketFrameHandlerMetadata createEndpointMetadata(Class<? extends Endpoint> endpointClass, EndpointConfig endpointConfig)
{
JavaxWebSocketFrameHandlerMetadata metadata = new JavaxWebSocketFrameHandlerMetadata(endpointConfig);
MethodHandles.Lookup lookup = getApplicationMethodHandleLookup(endpointClass);
Method openMethod = ReflectUtils.findMethod(endpointClass, "onOpen", Session.class, EndpointConfig.class);
MethodHandle open = toMethodHandle(lookup, openMethod);
metadata.setOpenHandler(open, openMethod);
Method closeMethod = ReflectUtils.findMethod(endpointClass, "onClose", Session.class, CloseReason.class);
MethodHandle close = toMethodHandle(lookup, closeMethod);
metadata.setCloseHandler(close, closeMethod);
Method errorMethod = ReflectUtils.findMethod(endpointClass, "onError", Session.class, Throwable.class);
MethodHandle error = toMethodHandle(lookup, errorMethod);
metadata.setErrorHandler(error, errorMethod);
return metadata;
}
protected JavaxWebSocketFrameHandlerMetadata discoverJavaxFrameHandlerMetadata(Class<?> endpointClass, JavaxWebSocketFrameHandlerMetadata metadata)
{
MethodHandles.Lookup lookup = getApplicationMethodHandleLookup(endpointClass);
Method onmethod;
// OnOpen [0..1]
onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnOpen.class);
if (onmethod != null)
{
assertSignatureValid(endpointClass, onmethod, OnOpen.class);
final InvokerUtils.Arg SESSION = new InvokerUtils.Arg(Session.class);
final InvokerUtils.Arg ENDPOINT_CONFIG = new InvokerUtils.Arg(EndpointConfig.class);
MethodHandle methodHandle = InvokerUtils
.mutatedInvoker(lookup, endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, ENDPOINT_CONFIG);
metadata.setOpenHandler(methodHandle, onmethod);
}
// OnClose [0..1]
onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnClose.class);
if (onmethod != null)
{
assertSignatureValid(endpointClass, onmethod, OnClose.class);
final InvokerUtils.Arg SESSION = new InvokerUtils.Arg(Session.class);
final InvokerUtils.Arg CLOSE_REASON = new InvokerUtils.Arg(CloseReason.class);
MethodHandle methodHandle = InvokerUtils
.mutatedInvoker(lookup, endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, CLOSE_REASON);
metadata.setCloseHandler(methodHandle, onmethod);
}
// OnError [0..1]
onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnError.class);
if (onmethod != null)
{
assertSignatureValid(endpointClass, onmethod, OnError.class);
final InvokerUtils.Arg SESSION = new InvokerUtils.Arg(Session.class);
final InvokerUtils.Arg CAUSE = new InvokerUtils.Arg(Throwable.class).required();
MethodHandle methodHandle = InvokerUtils
.mutatedInvoker(lookup, endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, CAUSE);
metadata.setErrorHandler(methodHandle, onmethod);
}
// OnMessage [0..2]
Method[] onMessages = ReflectUtils.findAnnotatedMethods(endpointClass, OnMessage.class);
if (onMessages != null && onMessages.length > 0)
{
for (Method onMsg : onMessages)
{
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
OnMessage onMessageAnno = onMsg.getAnnotation(OnMessage.class);
long annotationMaxMessageSize = onMessageAnno.maxMessageSize();
if (annotationMaxMessageSize > Integer.MAX_VALUE)
{
throw new InvalidWebSocketException(String.format("Value too large: %s#%s - @OnMessage.maxMessageSize=%,d > Integer.MAX_VALUE",
endpointClass.getName(), onMsg.getName(), annotationMaxMessageSize));
}
// Create MessageMetadata and set annotated maxMessageSize if it is not the default value.
JavaxWebSocketMessageMetadata msgMetadata = new JavaxWebSocketMessageMetadata();
if (annotationMaxMessageSize != -1)
msgMetadata.setMaxMessageSize((int)annotationMaxMessageSize);
// Function to search for matching MethodHandle for the endpointClass given a signature.
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle = (signature) ->
InvokerUtils.optionalMutatedInvoker(lookup, endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), signature);
// Try to match from available decoders (includes primitive types).
if (matchDecoders(onMsg, metadata, msgMetadata, getMethodHandle))
continue;
// No decoders matched try partial signatures and pong signatures.
if (matchOnMessage(onMsg, metadata, msgMetadata, getMethodHandle))
continue;
// Not a valid @OnMessage declaration signature.
throw InvalidSignatureException.build(endpointClass, OnMessage.class, onMsg);
}
}
return metadata;
}
private boolean matchOnMessage(Method onMsg, JavaxWebSocketFrameHandlerMetadata metadata, JavaxWebSocketMessageMetadata msgMetadata,
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle)
{
// Partial Text Message.
MethodHandle methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.textPartialCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(PartialStringMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
// Partial ByteBuffer Binary Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.binaryPartialBufferCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(PartialByteBufferMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Partial byte[] Binary Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.binaryPartialArrayCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(PartialByteArrayMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Pong Message.
MethodHandle pongHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.pongCallingArgs);
if (pongHandle != null)
{
metadata.setPongHandle(pongHandle, onMsg);
return true;
}
return false;
}
private boolean matchDecoders(Method onMsg, JavaxWebSocketFrameHandlerMetadata metadata, JavaxWebSocketMessageMetadata msgMetadata,
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle)
{
// Get the first decoder match.
RegisteredDecoder firstDecoder = metadata.getAvailableDecoders().stream()
.filter(registeredDecoder -> getMethodHandle.apply(getArgsFor(registeredDecoder.objectType)) != null)
.findFirst()
.orElse(null);
if (firstDecoder == null)
return false;
// Assemble a list of matching decoders which implement the interface type of the first matching decoder found.
List<RegisteredDecoder> decoders = new ArrayList<>();
Class<? extends Decoder> interfaceType = firstDecoder.interfaceType;
metadata.getAvailableDecoders().stream()
.filter(decoder ->
{
InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(decoder.objectType).required()};
return decoder.interfaceType.equals(interfaceType) && (getMethodHandle.apply(args) != null);
}).forEach(decoders::add);
msgMetadata.setRegisteredDecoders(decoders);
// Get the general methodHandle which applies to all the decoders in the list.
Class<?> objectType = firstDecoder.objectType;
for (RegisteredDecoder decoder : decoders)
{
if (decoder.objectType.isAssignableFrom(objectType))
objectType = decoder.objectType;
};
InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(objectType).required()};
MethodHandle methodHandle = getMethodHandle.apply(args);
msgMetadata.setMethodHandle(methodHandle);
// Set the sinkClass and then set the MessageMetadata on the FrameHandlerMetadata
if (interfaceType.equals(Decoder.Text.class))
{
msgMetadata.setSinkClass(DecodedTextMessageSink.class);
metadata.setTextMetadata(msgMetadata, onMsg);
}
else if (interfaceType.equals(Decoder.Binary.class))
{
msgMetadata.setSinkClass(DecodedBinaryMessageSink.class);
metadata.setBinaryMetadata(msgMetadata, onMsg);
}
else if (interfaceType.equals(Decoder.TextStream.class))
{
msgMetadata.setSinkClass(DecodedTextStreamMessageSink.class);
metadata.setTextMetadata(msgMetadata, onMsg);
}
else if (interfaceType.equals(Decoder.BinaryStream.class))
{
msgMetadata.setSinkClass(DecodedBinaryStreamMessageSink.class);
metadata.setBinaryMetadata(msgMetadata, onMsg);
}
return true;
}
private void assertSignatureValid(Class<?> endpointClass, Method method, Class<? extends Annotation> annotationClass)
{
// Test modifiers
int mods = method.getModifiers();
if (!Modifier.isPublic(mods))
{
StringBuilder err = new StringBuilder();
err.append("@").append(annotationClass.getSimpleName());
err.append(" method must be public: ");
ReflectUtils.append(err, endpointClass, method);
throw new InvalidSignatureException(err.toString());
}
if (Modifier.isStatic(mods))
{
StringBuilder err = new StringBuilder();
err.append("@").append(annotationClass.getSimpleName());
err.append(" method must not be static: ");
ReflectUtils.append(err, endpointClass, method);
throw new InvalidSignatureException(err.toString());
}
// Test return type
Class<?> returnType = method.getReturnType();
if ((returnType == Void.TYPE) || (returnType == Void.class))
{
// Void is 100% valid, always
return;
}
if (!OnMessage.class.isAssignableFrom(annotationClass))
{
StringBuilder err = new StringBuilder();
err.append("@").append(annotationClass.getSimpleName());
err.append(" return must be void: ");
ReflectUtils.append(err, endpointClass, method);
throw new InvalidSignatureException(err.toString());
}
}
/** /**
* Bind the URI Template Variables to their provided values, converting to the type * Bind the URI Template Variables to their provided values, converting to the type
* that the MethodHandle target has declared. * that the MethodHandle target has declared.
@ -260,382 +571,6 @@ public abstract class JavaxWebSocketFrameHandlerFactory
return retHandle; return retHandle;
} }
public static MessageSink createMessageSink(JavaxWebSocketSession session, JavaxWebSocketMessageMetadata msgMetadata)
{
if (msgMetadata == null)
return null;
try
{
MethodHandles.Lookup lookup = getServerMethodHandleLookup();
if (AbstractDecodedMessageSink.class.isAssignableFrom(msgMetadata.getSinkClass()))
{
MethodHandle ctorHandle = lookup.findConstructor(msgMetadata.getSinkClass(),
MethodType.methodType(void.class, CoreSession.class, MethodHandle.class, List.class));
List<RegisteredDecoder> registeredDecoders = msgMetadata.getRegisteredDecoders();
return (MessageSink)ctorHandle.invoke(session.getCoreSession(), msgMetadata.getMethodHandle(), registeredDecoders);
}
else
{
MethodHandle ctorHandle = lookup.findConstructor(msgMetadata.getSinkClass(),
MethodType.methodType(void.class, CoreSession.class, MethodHandle.class));
return (MessageSink)ctorHandle.invoke(session.getCoreSession(), msgMetadata.getMethodHandle());
}
}
catch (NoSuchMethodException e)
{
throw new RuntimeException("Missing expected MessageSink constructor found at: " + msgMetadata.getSinkClass().getName(), e);
}
catch (IllegalAccessException | InstantiationException | InvocationTargetException e)
{
throw new RuntimeException("Unable to create MessageSink: " + msgMetadata.getSinkClass().getName(), e);
}
catch (RuntimeException e)
{
throw e;
}
catch (Throwable t)
{
throw new RuntimeException(t);
}
}
public static MethodHandle wrapNonVoidReturnType(MethodHandle handle, JavaxWebSocketSession session)
{
if (handle == null)
return null;
if (handle.type().returnType() == Void.TYPE)
return handle;
// Technique from https://stackoverflow.com/questions/48505787/methodhandle-with-general-non-void-return-filter
// Change the return type of the to be Object so it will match exact with JavaxWebSocketSession.filterReturnType(Object)
handle = handle.asType(handle.type().changeReturnType(Object.class));
// Filter the method return type to a call to JavaxWebSocketSession.filterReturnType() bound to this session
handle = MethodHandles.filterReturnValue(handle, FILTER_RETURN_TYPE_METHOD.bindTo(session));
return handle;
}
private MethodHandle toMethodHandle(MethodHandles.Lookup lookup, Method method)
{
try
{
return lookup.unreflect(method);
}
catch (IllegalAccessException e)
{
throw new RuntimeException("Unable to access method " + method, e);
}
}
protected JavaxWebSocketFrameHandlerMetadata createEndpointMetadata(Class<? extends javax.websocket.Endpoint> endpointClass, EndpointConfig endpointConfig)
{
JavaxWebSocketFrameHandlerMetadata metadata = new JavaxWebSocketFrameHandlerMetadata(endpointConfig);
MethodHandles.Lookup lookup = getApplicationMethodHandleLookup(endpointClass);
Method openMethod = ReflectUtils.findMethod(endpointClass, "onOpen",
javax.websocket.Session.class, javax.websocket.EndpointConfig.class);
MethodHandle open = toMethodHandle(lookup, openMethod);
metadata.setOpenHandler(open, openMethod);
Method closeMethod = ReflectUtils.findMethod(endpointClass, "onClose",
javax.websocket.Session.class, javax.websocket.CloseReason.class);
MethodHandle close = toMethodHandle(lookup, closeMethod);
metadata.setCloseHandler(close, closeMethod);
Method errorMethod = ReflectUtils.findMethod(endpointClass, "onError",
javax.websocket.Session.class, Throwable.class);
MethodHandle error = toMethodHandle(lookup, errorMethod);
metadata.setErrorHandler(error, errorMethod);
return metadata;
}
protected JavaxWebSocketFrameHandlerMetadata discoverJavaxFrameHandlerMetadata(Class<?> endpointClass, JavaxWebSocketFrameHandlerMetadata metadata)
{
MethodHandles.Lookup lookup = getApplicationMethodHandleLookup(endpointClass);
Method onmethod;
// OnOpen [0..1]
onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnOpen.class);
if (onmethod != null)
{
assertSignatureValid(endpointClass, onmethod, OnOpen.class);
final InvokerUtils.Arg SESSION = new InvokerUtils.Arg(Session.class);
final InvokerUtils.Arg ENDPOINT_CONFIG = new InvokerUtils.Arg(EndpointConfig.class);
MethodHandle methodHandle = InvokerUtils
.mutatedInvoker(lookup, endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, ENDPOINT_CONFIG);
metadata.setOpenHandler(methodHandle, onmethod);
}
// OnClose [0..1]
onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnClose.class);
if (onmethod != null)
{
assertSignatureValid(endpointClass, onmethod, OnClose.class);
final InvokerUtils.Arg SESSION = new InvokerUtils.Arg(Session.class);
final InvokerUtils.Arg CLOSE_REASON = new InvokerUtils.Arg(CloseReason.class);
MethodHandle methodHandle = InvokerUtils
.mutatedInvoker(lookup, endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, CLOSE_REASON);
metadata.setCloseHandler(methodHandle, onmethod);
}
// OnError [0..1]
onmethod = ReflectUtils.findAnnotatedMethod(endpointClass, OnError.class);
if (onmethod != null)
{
assertSignatureValid(endpointClass, onmethod, OnError.class);
final InvokerUtils.Arg SESSION = new InvokerUtils.Arg(Session.class);
final InvokerUtils.Arg CAUSE = new InvokerUtils.Arg(Throwable.class).required();
MethodHandle methodHandle = InvokerUtils
.mutatedInvoker(lookup, endpointClass, onmethod, paramIdentifier, metadata.getNamedTemplateVariables(), SESSION, CAUSE);
metadata.setErrorHandler(methodHandle, onmethod);
}
// OnMessage [0..2]
Method[] onMessages = ReflectUtils.findAnnotatedMethods(endpointClass, OnMessage.class);
if (onMessages != null && onMessages.length > 0)
{
for (Method onMsg : onMessages)
{
assertSignatureValid(endpointClass, onMsg, OnMessage.class);
OnMessage onMessageAnno = onMsg.getAnnotation(OnMessage.class);
long annotationMaxMessageSize = onMessageAnno.maxMessageSize();
if (annotationMaxMessageSize > Integer.MAX_VALUE)
{
throw new InvalidWebSocketException(String.format("Value too large: %s#%s - @OnMessage.maxMessageSize=%,d > Integer.MAX_VALUE",
endpointClass.getName(), onMsg.getName(), annotationMaxMessageSize));
}
// Create MessageMetadata and set annotated maxMessageSize if it is not the default value.
JavaxWebSocketMessageMetadata msgMetadata = new JavaxWebSocketMessageMetadata();
if (annotationMaxMessageSize != -1)
msgMetadata.setMaxMessageSize((int)annotationMaxMessageSize);
// Function to search for matching MethodHandle for the endpointClass given a signature.
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle = (signature) ->
InvokerUtils.optionalMutatedInvoker(lookup, endpointClass, onMsg, paramIdentifier, metadata.getNamedTemplateVariables(), signature);
// Try to match from available decoders.
if (matchDecoders(onMsg, metadata, msgMetadata, getMethodHandle))
continue;
// No decoders matched try basic signatures to call onMessage directly.
if (matchOnMessage(onMsg, metadata, msgMetadata, getMethodHandle))
continue;
// Not a valid @OnMessage declaration signature.
throw InvalidSignatureException.build(endpointClass, OnMessage.class, onMsg);
}
}
return metadata;
}
private boolean matchOnMessage(Method onMsg, JavaxWebSocketFrameHandlerMetadata metadata, JavaxWebSocketMessageMetadata msgMetadata,
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle)
{
// Whole Text Message.
MethodHandle methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.textCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(StringMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
// Partial Text Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.textPartialCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(PartialStringMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
// Whole ByteBuffer Binary Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.binaryBufferCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(ByteBufferMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Partial ByteBuffer Binary Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.binaryPartialBufferCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(PartialByteBufferMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Whole byte[] Binary Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.binaryArrayCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(ByteArrayMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Partial byte[] Binary Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.binaryPartialArrayCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(PartialByteArrayMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// InputStream Binary Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.inputStreamCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(InputStreamMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setBinaryMetadata(msgMetadata, onMsg);
return true;
}
// Reader Text Message.
methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.readerCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(ReaderMessageSink.class);
msgMetadata.setMethodHandle(methodHandle);
metadata.setTextMetadata(msgMetadata, onMsg);
return true;
}
// Pong Message.
MethodHandle pongHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.pongCallingArgs);
if (pongHandle != null)
{
metadata.setPongHandle(pongHandle, onMsg);
return true;
}
return false;
}
private boolean matchDecoders(Method onMsg, JavaxWebSocketFrameHandlerMetadata metadata, JavaxWebSocketMessageMetadata msgMetadata,
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle)
{
// We need to get all the decoders which match not just the first.
Stream<RegisteredDecoder> matchedDecodersStream = metadata.getAvailableDecoders().stream().filter(registeredDecoder ->
{
InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(registeredDecoder.objectType).required()};
return getMethodHandle.apply(args) != null;
});
// Use the interface type of the first matched decoder.
RegisteredDecoder firstDecoder = matchedDecodersStream.findFirst().orElse(null);
if (firstDecoder == null)
return false;
// TODO: COMMENT
List<RegisteredDecoder> decoders = new ArrayList<>();
Class<? extends Decoder> interfaceType = firstDecoder.interfaceType;
metadata.getAvailableDecoders().stream()
.filter(registeredDecoder -> registeredDecoder.interfaceType.equals(interfaceType))
.forEach(decoders::add);
// Get the original argument type.
Class<?> type = firstDecoder.objectType;
for (Class<?> clazz : onMsg.getParameterTypes())
{
if (clazz.isAssignableFrom(firstDecoder.objectType))
type = clazz;
}
InvokerUtils.Arg[] generalArgs = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(type).required()};
MethodHandle generalMethodHandle = getMethodHandle.apply(generalArgs);
if (generalMethodHandle == null)
{
// TODO: warn or throw
return false;
}
msgMetadata.setRegisteredDecoders(decoders);
msgMetadata.setMethodHandle(generalMethodHandle);
if (interfaceType.equals(Decoder.Text.class))
{
msgMetadata.setSinkClass(DecodedTextMessageSink.class);
metadata.setTextMetadata(msgMetadata, onMsg);
}
else if (interfaceType.equals(Decoder.Binary.class))
{
msgMetadata.setSinkClass(DecodedBinaryMessageSink.class);
metadata.setBinaryMetadata(msgMetadata, onMsg);
}
else if (interfaceType.equals(Decoder.TextStream.class))
{
msgMetadata.setSinkClass(DecodedTextStreamMessageSink.class);
metadata.setTextMetadata(msgMetadata, onMsg);
}
else if (interfaceType.equals(Decoder.BinaryStream.class))
{
msgMetadata.setSinkClass(DecodedBinaryStreamMessageSink.class);
metadata.setBinaryMetadata(msgMetadata, onMsg);
}
return true;
}
private void assertSignatureValid(Class<?> endpointClass, Method method, Class<? extends Annotation> annotationClass)
{
// Test modifiers
int mods = method.getModifiers();
if (!Modifier.isPublic(mods))
{
StringBuilder err = new StringBuilder();
err.append("@").append(annotationClass.getSimpleName());
err.append(" method must be public: ");
ReflectUtils.append(err, endpointClass, method);
throw new InvalidSignatureException(err.toString());
}
if (Modifier.isStatic(mods))
{
StringBuilder err = new StringBuilder();
err.append("@").append(annotationClass.getSimpleName());
err.append(" method must not be static: ");
ReflectUtils.append(err, endpointClass, method);
throw new InvalidSignatureException(err.toString());
}
// Test return type
Class<?> returnType = method.getReturnType();
if ((returnType == Void.TYPE) || (returnType == Void.class))
{
// Void is 100% valid, always
return;
}
if (!OnMessage.class.isAssignableFrom(annotationClass))
{
StringBuilder err = new StringBuilder();
err.append("@").append(annotationClass.getSimpleName());
err.append(" return must be void: ");
ReflectUtils.append(err, endpointClass, method);
throw new InvalidSignatureException(err.toString());
}
}
/** /**
* <p> * <p>
* Gives a {@link MethodHandles.Lookup} instance to be used to find methods in server classes. * Gives a {@link MethodHandles.Lookup} instance to be used to find methods in server classes.

View File

@ -33,7 +33,7 @@ public class SessionTracker extends AbstractLifeCycle implements JavaxWebSocketS
{ {
private static final Logger LOG = LoggerFactory.getLogger(SessionTracker.class); private static final Logger LOG = LoggerFactory.getLogger(SessionTracker.class);
private CopyOnWriteArraySet<JavaxWebSocketSession> sessions = new CopyOnWriteArraySet<>(); private final CopyOnWriteArraySet<JavaxWebSocketSession> sessions = new CopyOnWriteArraySet<>();
public Set<Session> getSessions() public Set<Session> getSessions()
{ {

View File

@ -1,60 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.javax.common.decoders;
import java.nio.ByteBuffer;
import javax.websocket.DecodeException;
import javax.websocket.Decoder;
import javax.websocket.PongMessage;
import org.eclipse.jetty.util.BufferUtil;
public class PongMessageDecoder extends AbstractDecoder implements Decoder.Binary<PongMessage>
{
private static class PongMsg implements PongMessage
{
private final ByteBuffer bytes;
public PongMsg(ByteBuffer buf)
{
int len = buf.remaining();
this.bytes = ByteBuffer.allocate(len);
BufferUtil.put(buf, this.bytes);
BufferUtil.flipToFlush(this.bytes, 0);
}
@Override
public ByteBuffer getApplicationData()
{
return this.bytes;
}
}
@Override
public PongMessage decode(ByteBuffer bytes) throws DecodeException
{
return new PongMsg(bytes);
}
@Override
public boolean willDecode(ByteBuffer bytes)
{
return true;
}
}

View File

@ -79,7 +79,7 @@ public class DecoderListTest
Arguments.of("=DecodeEquals", "DecodeEquals="), Arguments.of("=DecodeEquals", "DecodeEquals="),
Arguments.of("+DecodePlus", "DecodePlus+"), Arguments.of("+DecodePlus", "DecodePlus+"),
Arguments.of("-DecodeMinus", "DecodeMinus-"), Arguments.of("-DecodeMinus", "DecodeMinus-"),
Arguments.of("DecodeNoMatch", null) Arguments.of("DecodeNoMatch", "DecodeNoMatch")
); );
} }