Issue #3428 - cleanups and simplify MessageHandler registration

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-05-21 14:58:46 +10:00
parent 8e554c7d13
commit 43e3cdc4e3
4 changed files with 107 additions and 104 deletions

View File

@ -19,11 +19,9 @@
package org.eclipse.jetty.websocket.javax.common;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
@ -368,54 +366,32 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
return textMetadata;
}
private void assertBasicTypeNotRegistered(byte basicWebSocketType, Object messageImpl, String replacement)
{
if (messageImpl != null)
{
throw new IllegalStateException(
"Cannot register " + replacement + ": Basic WebSocket type " + OpCode.name(basicWebSocketType) + " is already registered");
}
}
public <T> void addMessageHandler(Class<T> clazz, MessageHandler.Partial<T> handler)
{
try
{
MethodHandles.Lookup lookup = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup();
MethodHandle partialMessageHandler = lookup
.findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(void.class, Object.class, boolean.class));
partialMessageHandler = partialMessageHandler.bindTo(handler);
MethodHandle methodHandle = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
.findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(void.class, Object.class, boolean.class))
.bindTo(handler);
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(methodHandle);
byte basicType;
// MessageHandler.Partial has no decoder support!
if (byte[].class.isAssignableFrom(clazz))
{
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
MessageSink messageSink = new PartialByteArrayMessageSink(coreSession, partialMessageHandler);
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(partialMessageHandler);
basicType = OpCode.BINARY;
metadata.setSinkClass(PartialByteArrayMessageSink.class);
this.binaryMetadata = metadata;
}
else if (ByteBuffer.class.isAssignableFrom(clazz))
{
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
MessageSink messageSink = new PartialByteBufferMessageSink(coreSession, partialMessageHandler);
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(partialMessageHandler);
basicType = OpCode.BINARY;
metadata.setSinkClass(PartialByteBufferMessageSink.class);
this.binaryMetadata = metadata;
}
else if (String.class.isAssignableFrom(clazz))
{
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
MessageSink messageSink = new PartialStringMessageSink(coreSession, partialMessageHandler);
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(partialMessageHandler);
basicType = OpCode.TEXT;
metadata.setSinkClass(PartialStringMessageSink.class);
this.textMetadata = metadata;
}
else
{
@ -423,6 +399,9 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
"Unable to add " + handler.getClass().getName() + " with type " + clazz + ": only supported types byte[], " + ByteBuffer.class.getName() +
", " + String.class.getName());
}
// Register the Metadata as a MessageHandler.
registerMessageHandler(clazz, handler, basicType, metadata);
}
catch (NoSuchMethodException e)
{
@ -444,68 +423,52 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
if (PongMessage.class.isAssignableFrom(clazz))
{
assertBasicTypeNotRegistered(OpCode.PONG, this.pongHandle, handler.getClass().getName());
assertBasicTypeNotRegistered(OpCode.PONG, handler);
this.pongHandle = methodHandle;
registerMessageHandler(OpCode.PONG, clazz, handler, null);
return;
}
AvailableDecoders availableDecoders = session.getDecoders();
RegisteredDecoder registeredDecoder = availableDecoders.getFirstRegisteredDecoder(clazz);
if (registeredDecoder == null)
throw new IllegalStateException("Unable to find Decoder for type: " + clazz);
// Create the message metadata specific to the MessageHandler type.
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(methodHandle);
byte basicType;
if (registeredDecoder.implementsInterface(Decoder.Binary.class))
{
basicType = OpCode.BINARY;
metadata.setRegisteredDecoders(availableDecoders.getBinaryDecoders(clazz));
metadata.setSinkClass(DecodedBinaryMessageSink.class);
}
else if (registeredDecoder.implementsInterface(Decoder.BinaryStream.class))
{
basicType = OpCode.BINARY;
metadata.setRegisteredDecoders(availableDecoders.getBinaryStreamDecoders(clazz));
metadata.setSinkClass(DecodedBinaryStreamMessageSink.class);
}
else if (registeredDecoder.implementsInterface(Decoder.Text.class))
{
basicType = OpCode.TEXT;
metadata.setRegisteredDecoders(availableDecoders.getTextDecoders(clazz));
metadata.setSinkClass(DecodedTextMessageSink.class);
}
else if (registeredDecoder.implementsInterface(Decoder.TextStream.class))
{
basicType = OpCode.TEXT;
metadata.setRegisteredDecoders(availableDecoders.getTextStreamDecoders(clazz));
metadata.setSinkClass(DecodedTextStreamMessageSink.class);
}
else
{
AvailableDecoders availableDecoders = session.getDecoders();
RegisteredDecoder registeredDecoder = availableDecoders.getFirstRegisteredDecoder(clazz);
if (registeredDecoder == null)
{
throw new IllegalStateException("Unable to find Decoder for type: " + clazz);
}
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(methodHandle);
if (registeredDecoder.implementsInterface(Decoder.Binary.class))
{
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
List<RegisteredDecoder> binaryDecoders = availableDecoders.getBinaryDecoders(clazz);
metadata.setRegisteredDecoders(binaryDecoders);
MessageSink messageSink = new DecodedBinaryMessageSink<T>(coreSession, methodHandle, binaryDecoders);
metadata.setSinkClass(messageSink.getClass());
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
this.binaryMetadata = metadata;
}
else if (registeredDecoder.implementsInterface(Decoder.BinaryStream.class))
{
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
List<RegisteredDecoder> binaryStreamDecoders = availableDecoders.getBinaryStreamDecoders(clazz);
metadata.setRegisteredDecoders(binaryStreamDecoders);
MessageSink messageSink = new DecodedBinaryStreamMessageSink<T>(coreSession, methodHandle, binaryStreamDecoders);
metadata.setSinkClass(messageSink.getClass());
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
this.binaryMetadata = metadata;
}
else if (registeredDecoder.implementsInterface(Decoder.Text.class))
{
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
List<RegisteredDecoder> textDecoders = availableDecoders.getTextDecoders(clazz);
metadata.setRegisteredDecoders(textDecoders);
MessageSink messageSink = new DecodedTextMessageSink<T>(coreSession, methodHandle, textDecoders);
metadata.setSinkClass(messageSink.getClass());
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
this.textMetadata = metadata;
}
else if (registeredDecoder.implementsInterface(Decoder.TextStream.class))
{
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
List<RegisteredDecoder> textStreamDecoders = availableDecoders.getTextStreamDecoders(clazz);
metadata.setRegisteredDecoders(textStreamDecoders);
MessageSink messageSink = new DecodedTextStreamMessageSink<T>(coreSession, methodHandle, textStreamDecoders);
metadata.setSinkClass(messageSink.getClass());
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
this.textMetadata = metadata;
}
else
{
throw new RuntimeException("Unable to add " + handler.getClass().getName() + ": type " + clazz + " is unrecognized by declared decoders");
}
throw new RuntimeException("Unable to add " + handler.getClass().getName() + ": type " + clazz + " is unrecognized by declared decoders");
}
// Register the Metadata as a MessageHandler.
registerMessageHandler(clazz, handler, basicType, metadata);
}
catch (NoSuchMethodException e)
{
@ -517,6 +480,50 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
}
}
private void assertBasicTypeNotRegistered(byte basicWebSocketType, MessageHandler replacement)
{
Object messageImpl;
switch (basicWebSocketType)
{
case OpCode.TEXT:
messageImpl = textSink;
break;
case OpCode.BINARY:
messageImpl = binarySink;
break;
case OpCode.PONG:
messageImpl = pongHandle;
break;
default:
throw new IllegalStateException();
}
if (messageImpl != null)
{
throw new IllegalStateException("Cannot register " + replacement.getClass().getName() +
": Basic WebSocket type " + OpCode.name(basicWebSocketType) + " is already registered");
}
}
private void registerMessageHandler(Class<?> clazz, MessageHandler handler, byte basicMessageType, JavaxWebSocketMessageMetadata metadata)
{
assertBasicTypeNotRegistered(basicMessageType, handler);
MessageSink messageSink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, metadata);
switch (basicMessageType)
{
case OpCode.TEXT:
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
this.textMetadata = metadata;
break;
case OpCode.BINARY:
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
this.binaryMetadata = metadata;
break;
default:
throw new IllegalStateException();
}
}
private <T> MessageSink registerMessageHandler(byte basicWebSocketMessageType, Class<T> handlerType, MessageHandler handler, MessageSink messageSink)
{
synchronized (messageHandlerMap)
@ -563,7 +570,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
this.binarySink = null;
break;
default:
break; // TODO ISE?
throw new IllegalStateException();
}
}
}

View File

@ -327,7 +327,7 @@ public abstract class JavaxWebSocketFrameHandlerFactory
Function<InvokerUtils.Arg[], MethodHandle> getMethodHandle)
{
// Partial Text Message.
MethodHandle methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.textPartialCallingArgs);
MethodHandle methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.textPartialCallingArgs);
if (methodHandle != null)
{
msgMetadata.setSinkClass(PartialStringMessageSink.class);
@ -381,12 +381,9 @@ public abstract class JavaxWebSocketFrameHandlerFactory
// Assemble a list of matching decoders which implement the interface type of the first matching decoder found.
List<RegisteredDecoder> decoders = new ArrayList<>();
Class<? extends Decoder> interfaceType = firstDecoder.interfaceType;
metadata.getAvailableDecoders().stream()
.filter(decoder ->
{
InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(decoder.objectType).required()};
return decoder.interfaceType.equals(interfaceType) && (getMethodHandle.apply(args) != null);
}).forEach(decoders::add);
metadata.getAvailableDecoders().stream().filter(decoder ->
decoder.interfaceType.equals(interfaceType) && (getMethodHandle.apply(getArgsFor(decoder.objectType)) != null))
.forEach(decoders::add);
msgMetadata.setRegisteredDecoders(decoders);
// Get the general methodHandle which applies to all the decoders in the list.
@ -395,9 +392,8 @@ public abstract class JavaxWebSocketFrameHandlerFactory
{
if (decoder.objectType.isAssignableFrom(objectType))
objectType = decoder.objectType;
};
InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(objectType).required()};
MethodHandle methodHandle = getMethodHandle.apply(args);
}
MethodHandle methodHandle = getMethodHandle.apply(getArgsFor(objectType));
msgMetadata.setMethodHandle(methodHandle);
// Set the sinkClass and then set the MessageMetadata on the FrameHandlerMetadata

View File

@ -51,7 +51,7 @@ public class DecodedTextMessageSinkTest extends AbstractMessageSinkTest
CompletableFuture<Date> copyFuture = new CompletableFuture<>();
DecodedDateCopy copy = new DecodedDateCopy(copyFuture);
MethodHandle copyHandle = getAcceptHandle(copy, Date.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class);
DecodedTextMessageSink<Calendar> sink = new DecodedTextMessageSink<>(session.getCoreSession(), copyHandle, decoders);
FutureCallback finCallback = new FutureCallback();
@ -69,7 +69,7 @@ public class DecodedTextMessageSinkTest extends AbstractMessageSinkTest
CompletableFuture<Date> copyFuture = new CompletableFuture<>();
DecodedDateCopy copy = new DecodedDateCopy(copyFuture);
MethodHandle copyHandle = getAcceptHandle(copy, Date.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class);
DecodedTextMessageSink<Calendar> sink = new DecodedTextMessageSink<>(session.getCoreSession(), copyHandle, decoders);
FutureCallback callback1 = new FutureCallback();

View File

@ -54,7 +54,7 @@ public class DecodedTextStreamMessageSinkTest extends AbstractMessageSinkTest
CompletableFuture<Date> copyFuture = new CompletableFuture<>();
DecodedDateCopy copy = new DecodedDateCopy(copyFuture);
MethodHandle copyHandle = getAcceptHandle(copy, Date.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class);
DecodedTextStreamMessageSink<Calendar> sink = new DecodedTextStreamMessageSink<>(session.getCoreSession(), copyHandle, decoders);
FutureCallback finCallback = new FutureCallback();
@ -72,7 +72,7 @@ public class DecodedTextStreamMessageSinkTest extends AbstractMessageSinkTest
CompletableFuture<Date> copyFuture = new CompletableFuture<>();
DecodedDateCopy copy = new DecodedDateCopy(copyFuture);
MethodHandle copyHandle = getAcceptHandle(copy, Date.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class);
List<RegisteredDecoder> decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class);
DecodedTextStreamMessageSink<Calendar> sink = new DecodedTextStreamMessageSink<>(session.getCoreSession(), copyHandle, decoders);
FutureCallback callback1 = new FutureCallback();