diff --git a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java index f89d7bb8b0a..9ba18755761 100644 --- a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java +++ b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java @@ -19,11 +19,9 @@ package org.eclipse.jetty.websocket.javax.common; import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -368,54 +366,32 @@ public class JavaxWebSocketFrameHandler implements FrameHandler return textMetadata; } - private void assertBasicTypeNotRegistered(byte basicWebSocketType, Object messageImpl, String replacement) - { - if (messageImpl != null) - { - throw new IllegalStateException( - "Cannot register " + replacement + ": Basic WebSocket type " + OpCode.name(basicWebSocketType) + " is already registered"); - } - } - public void addMessageHandler(Class clazz, MessageHandler.Partial handler) { try { - MethodHandles.Lookup lookup = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup(); - MethodHandle partialMessageHandler = lookup - .findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(void.class, Object.class, boolean.class)); - partialMessageHandler = partialMessageHandler.bindTo(handler); + MethodHandle methodHandle = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup() + .findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(void.class, Object.class, boolean.class)) + .bindTo(handler); + JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata(); + metadata.setMethodHandle(methodHandle); + byte basicType; // MessageHandler.Partial has no decoder support! if (byte[].class.isAssignableFrom(clazz)) { - assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName()); - MessageSink messageSink = new PartialByteArrayMessageSink(coreSession, partialMessageHandler); - this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink); - JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata(); - metadata.setMethodHandle(partialMessageHandler); + basicType = OpCode.BINARY; metadata.setSinkClass(PartialByteArrayMessageSink.class); - this.binaryMetadata = metadata; } else if (ByteBuffer.class.isAssignableFrom(clazz)) { - assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName()); - MessageSink messageSink = new PartialByteBufferMessageSink(coreSession, partialMessageHandler); - this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink); - JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata(); - metadata.setMethodHandle(partialMessageHandler); + basicType = OpCode.BINARY; metadata.setSinkClass(PartialByteBufferMessageSink.class); - this.binaryMetadata = metadata; } else if (String.class.isAssignableFrom(clazz)) { - assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName()); - MessageSink messageSink = new PartialStringMessageSink(coreSession, partialMessageHandler); - this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink); - JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata(); - metadata.setMethodHandle(partialMessageHandler); + basicType = OpCode.TEXT; metadata.setSinkClass(PartialStringMessageSink.class); - this.textMetadata = metadata; } else { @@ -423,6 +399,9 @@ public class JavaxWebSocketFrameHandler implements FrameHandler "Unable to add " + handler.getClass().getName() + " with type " + clazz + ": only supported types byte[], " + ByteBuffer.class.getName() + ", " + String.class.getName()); } + + // Register the Metadata as a MessageHandler. + registerMessageHandler(clazz, handler, basicType, metadata); } catch (NoSuchMethodException e) { @@ -444,68 +423,52 @@ public class JavaxWebSocketFrameHandler implements FrameHandler if (PongMessage.class.isAssignableFrom(clazz)) { - assertBasicTypeNotRegistered(OpCode.PONG, this.pongHandle, handler.getClass().getName()); + assertBasicTypeNotRegistered(OpCode.PONG, handler); this.pongHandle = methodHandle; registerMessageHandler(OpCode.PONG, clazz, handler, null); + return; + } + + AvailableDecoders availableDecoders = session.getDecoders(); + RegisteredDecoder registeredDecoder = availableDecoders.getFirstRegisteredDecoder(clazz); + if (registeredDecoder == null) + throw new IllegalStateException("Unable to find Decoder for type: " + clazz); + + // Create the message metadata specific to the MessageHandler type. + JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata(); + metadata.setMethodHandle(methodHandle); + byte basicType; + if (registeredDecoder.implementsInterface(Decoder.Binary.class)) + { + basicType = OpCode.BINARY; + metadata.setRegisteredDecoders(availableDecoders.getBinaryDecoders(clazz)); + metadata.setSinkClass(DecodedBinaryMessageSink.class); + } + else if (registeredDecoder.implementsInterface(Decoder.BinaryStream.class)) + { + basicType = OpCode.BINARY; + metadata.setRegisteredDecoders(availableDecoders.getBinaryStreamDecoders(clazz)); + metadata.setSinkClass(DecodedBinaryStreamMessageSink.class); + } + else if (registeredDecoder.implementsInterface(Decoder.Text.class)) + { + basicType = OpCode.TEXT; + metadata.setRegisteredDecoders(availableDecoders.getTextDecoders(clazz)); + metadata.setSinkClass(DecodedTextMessageSink.class); + } + else if (registeredDecoder.implementsInterface(Decoder.TextStream.class)) + { + basicType = OpCode.TEXT; + metadata.setRegisteredDecoders(availableDecoders.getTextStreamDecoders(clazz)); + metadata.setSinkClass(DecodedTextStreamMessageSink.class); } else { - AvailableDecoders availableDecoders = session.getDecoders(); - - RegisteredDecoder registeredDecoder = availableDecoders.getFirstRegisteredDecoder(clazz); - if (registeredDecoder == null) - { - throw new IllegalStateException("Unable to find Decoder for type: " + clazz); - } - - JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata(); - metadata.setMethodHandle(methodHandle); - - if (registeredDecoder.implementsInterface(Decoder.Binary.class)) - { - assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName()); - List binaryDecoders = availableDecoders.getBinaryDecoders(clazz); - metadata.setRegisteredDecoders(binaryDecoders); - MessageSink messageSink = new DecodedBinaryMessageSink(coreSession, methodHandle, binaryDecoders); - metadata.setSinkClass(messageSink.getClass()); - this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink); - this.binaryMetadata = metadata; - } - else if (registeredDecoder.implementsInterface(Decoder.BinaryStream.class)) - { - assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName()); - List binaryStreamDecoders = availableDecoders.getBinaryStreamDecoders(clazz); - metadata.setRegisteredDecoders(binaryStreamDecoders); - MessageSink messageSink = new DecodedBinaryStreamMessageSink(coreSession, methodHandle, binaryStreamDecoders); - metadata.setSinkClass(messageSink.getClass()); - this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink); - this.binaryMetadata = metadata; - } - else if (registeredDecoder.implementsInterface(Decoder.Text.class)) - { - assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName()); - List textDecoders = availableDecoders.getTextDecoders(clazz); - metadata.setRegisteredDecoders(textDecoders); - MessageSink messageSink = new DecodedTextMessageSink(coreSession, methodHandle, textDecoders); - metadata.setSinkClass(messageSink.getClass()); - this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink); - this.textMetadata = metadata; - } - else if (registeredDecoder.implementsInterface(Decoder.TextStream.class)) - { - assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName()); - List textStreamDecoders = availableDecoders.getTextStreamDecoders(clazz); - metadata.setRegisteredDecoders(textStreamDecoders); - MessageSink messageSink = new DecodedTextStreamMessageSink(coreSession, methodHandle, textStreamDecoders); - metadata.setSinkClass(messageSink.getClass()); - this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink); - this.textMetadata = metadata; - } - else - { - throw new RuntimeException("Unable to add " + handler.getClass().getName() + ": type " + clazz + " is unrecognized by declared decoders"); - } + throw new RuntimeException("Unable to add " + handler.getClass().getName() + ": type " + clazz + " is unrecognized by declared decoders"); } + + // Register the Metadata as a MessageHandler. + registerMessageHandler(clazz, handler, basicType, metadata); } catch (NoSuchMethodException e) { @@ -517,6 +480,50 @@ public class JavaxWebSocketFrameHandler implements FrameHandler } } + private void assertBasicTypeNotRegistered(byte basicWebSocketType, MessageHandler replacement) + { + Object messageImpl; + switch (basicWebSocketType) + { + case OpCode.TEXT: + messageImpl = textSink; + break; + case OpCode.BINARY: + messageImpl = binarySink; + break; + case OpCode.PONG: + messageImpl = pongHandle; + break; + default: + throw new IllegalStateException(); + } + + if (messageImpl != null) + { + throw new IllegalStateException("Cannot register " + replacement.getClass().getName() + + ": Basic WebSocket type " + OpCode.name(basicWebSocketType) + " is already registered"); + } + } + + private void registerMessageHandler(Class clazz, MessageHandler handler, byte basicMessageType, JavaxWebSocketMessageMetadata metadata) + { + assertBasicTypeNotRegistered(basicMessageType, handler); + MessageSink messageSink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, metadata); + switch (basicMessageType) + { + case OpCode.TEXT: + this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink); + this.textMetadata = metadata; + break; + case OpCode.BINARY: + this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink); + this.binaryMetadata = metadata; + break; + default: + throw new IllegalStateException(); + } + } + private MessageSink registerMessageHandler(byte basicWebSocketMessageType, Class handlerType, MessageHandler handler, MessageSink messageSink) { synchronized (messageHandlerMap) @@ -563,7 +570,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler this.binarySink = null; break; default: - break; // TODO ISE? + throw new IllegalStateException(); } } } diff --git a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java index 972cd5f24c1..ed05db4385a 100644 --- a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java +++ b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java @@ -327,7 +327,7 @@ public abstract class JavaxWebSocketFrameHandlerFactory Function getMethodHandle) { // Partial Text Message. - MethodHandle methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.textPartialCallingArgs); + MethodHandle methodHandle = getMethodHandle.apply(JavaxWebSocketCallingArgs.textPartialCallingArgs); if (methodHandle != null) { msgMetadata.setSinkClass(PartialStringMessageSink.class); @@ -381,12 +381,9 @@ public abstract class JavaxWebSocketFrameHandlerFactory // Assemble a list of matching decoders which implement the interface type of the first matching decoder found. List decoders = new ArrayList<>(); Class interfaceType = firstDecoder.interfaceType; - metadata.getAvailableDecoders().stream() - .filter(decoder -> - { - InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(decoder.objectType).required()}; - return decoder.interfaceType.equals(interfaceType) && (getMethodHandle.apply(args) != null); - }).forEach(decoders::add); + metadata.getAvailableDecoders().stream().filter(decoder -> + decoder.interfaceType.equals(interfaceType) && (getMethodHandle.apply(getArgsFor(decoder.objectType)) != null)) + .forEach(decoders::add); msgMetadata.setRegisteredDecoders(decoders); // Get the general methodHandle which applies to all the decoders in the list. @@ -395,9 +392,8 @@ public abstract class JavaxWebSocketFrameHandlerFactory { if (decoder.objectType.isAssignableFrom(objectType)) objectType = decoder.objectType; - }; - InvokerUtils.Arg[] args = {new InvokerUtils.Arg(Session.class), new InvokerUtils.Arg(objectType).required()}; - MethodHandle methodHandle = getMethodHandle.apply(args); + } + MethodHandle methodHandle = getMethodHandle.apply(getArgsFor(objectType)); msgMetadata.setMethodHandle(methodHandle); // Set the sinkClass and then set the MessageMetadata on the FrameHandlerMetadata diff --git a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextMessageSinkTest.java b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextMessageSinkTest.java index 4a4a5e2012a..a915adfd7ef 100644 --- a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextMessageSinkTest.java +++ b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextMessageSinkTest.java @@ -51,7 +51,7 @@ public class DecodedTextMessageSinkTest extends AbstractMessageSinkTest CompletableFuture copyFuture = new CompletableFuture<>(); DecodedDateCopy copy = new DecodedDateCopy(copyFuture); MethodHandle copyHandle = getAcceptHandle(copy, Date.class); - List decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class); + List decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class); DecodedTextMessageSink sink = new DecodedTextMessageSink<>(session.getCoreSession(), copyHandle, decoders); FutureCallback finCallback = new FutureCallback(); @@ -69,7 +69,7 @@ public class DecodedTextMessageSinkTest extends AbstractMessageSinkTest CompletableFuture copyFuture = new CompletableFuture<>(); DecodedDateCopy copy = new DecodedDateCopy(copyFuture); MethodHandle copyHandle = getAcceptHandle(copy, Date.class); - List decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class); + List decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class); DecodedTextMessageSink sink = new DecodedTextMessageSink<>(session.getCoreSession(), copyHandle, decoders); FutureCallback callback1 = new FutureCallback(); diff --git a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextStreamMessageSinkTest.java b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextStreamMessageSinkTest.java index 051a91b4595..fe35b4c80b9 100644 --- a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextStreamMessageSinkTest.java +++ b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/DecodedTextStreamMessageSinkTest.java @@ -54,7 +54,7 @@ public class DecodedTextStreamMessageSinkTest extends AbstractMessageSinkTest CompletableFuture copyFuture = new CompletableFuture<>(); DecodedDateCopy copy = new DecodedDateCopy(copyFuture); MethodHandle copyHandle = getAcceptHandle(copy, Date.class); - List decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class); + List decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class); DecodedTextStreamMessageSink sink = new DecodedTextStreamMessageSink<>(session.getCoreSession(), copyHandle, decoders); FutureCallback finCallback = new FutureCallback(); @@ -72,7 +72,7 @@ public class DecodedTextStreamMessageSinkTest extends AbstractMessageSinkTest CompletableFuture copyFuture = new CompletableFuture<>(); DecodedDateCopy copy = new DecodedDateCopy(copyFuture); MethodHandle copyHandle = getAcceptHandle(copy, Date.class); - List decoders = toRegisteredDecoderList(DecodedBinaryStreamMessageSinkTest.GmtDecoder.class, Calendar.class); + List decoders = toRegisteredDecoderList(GmtDecoder.class, Calendar.class); DecodedTextStreamMessageSink sink = new DecodedTextStreamMessageSink<>(session.getCoreSession(), copyHandle, decoders); FutureCallback callback1 = new FutureCallback();