diff --git a/spring-websockets/pom.xml b/spring-websockets/pom.xml index a28ef8749a..28c875d50d 100644 --- a/spring-websockets/pom.xml +++ b/spring-websockets/pom.xml @@ -1,7 +1,7 @@ + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 spring-websockets spring-websockets diff --git a/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StockTicksController.java b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StockTicksController.java new file mode 100644 index 0000000000..0942657c33 --- /dev/null +++ b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StockTicksController.java @@ -0,0 +1,39 @@ +package com.baeldung.debugwebsockets; + +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Controller; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; + +@Controller +public class StockTicksController { + private final SimpMessagingTemplate simpMessagingTemplate; + + public StockTicksController(SimpMessagingTemplate simpMessagingTemplate) { + this.simpMessagingTemplate = simpMessagingTemplate; + } + + @Scheduled(fixedRate = 3000) + public void sendTicks() { + simpMessagingTemplate.convertAndSend("/topic/ticks", getStockTicks()); + } + + private Map getStockTicks() { + Map ticks = new HashMap<>(); + ticks.put("AAPL", getRandomTick()); + ticks.put("GOOGL", getRandomTick()); + ticks.put("MSFT", getRandomTick()); + ticks.put("TSLA", getRandomTick()); + ticks.put("AMZN", getRandomTick()); + ticks.put("HPE", getRandomTick()); + + return ticks; + } + + private int getRandomTick() { + return ThreadLocalRandom.current().nextInt(-100, 100 + 1); + } +} \ No newline at end of file diff --git a/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StompClientSessionHandler.java b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StompClientSessionHandler.java new file mode 100644 index 0000000000..535be79cee --- /dev/null +++ b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StompClientSessionHandler.java @@ -0,0 +1,31 @@ +package com.baeldung.debugwebsockets; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.messaging.simp.stomp.StompHeaders; +import org.springframework.messaging.simp.stomp.StompSession; +import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter; + +import java.lang.reflect.Type; +import java.util.Map; + +public class StompClientSessionHandler extends StompSessionHandlerAdapter { + private static final Logger logger = LoggerFactory.getLogger("StompClientSessionHandler"); + + @Override + public void afterConnected(StompSession session, StompHeaders connectedHeaders) { + logger.info("New session established. Session Id -> {}", session.getSessionId()); + session.subscribe("/topic/ticks", this); + logger.info("Subscribed to topic: /topic/ticks"); + } + + @Override + public void handleFrame(StompHeaders headers, Object payload) { + logger.info("Payload -> {}", payload); + } + + @Override + public Type getPayloadType(StompHeaders headers) { + return Map.class; + } +} diff --git a/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StompWebSocketClient.java b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StompWebSocketClient.java new file mode 100644 index 0000000000..0cbe32bf65 --- /dev/null +++ b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/StompWebSocketClient.java @@ -0,0 +1,24 @@ +package com.baeldung.debugwebsockets; + +import org.springframework.messaging.converter.MappingJackson2MessageConverter; +import org.springframework.messaging.simp.stomp.StompSessionHandler; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.client.standard.StandardWebSocketClient; +import org.springframework.web.socket.messaging.WebSocketStompClient; + +import java.util.Scanner; + +public class StompWebSocketClient { + + private static final String URL = "ws://localhost:8080/stock-ticks/websocket"; + + public static void main(String[] args) { + WebSocketClient client = new StandardWebSocketClient(); + WebSocketStompClient stompClient = new WebSocketStompClient(client); + stompClient.setMessageConverter(new MappingJackson2MessageConverter()); + StompSessionHandler sessionHandler = new StompClientSessionHandler(); + stompClient.connect(URL, sessionHandler); + + new Scanner(System.in).nextLine(); + } +} diff --git a/spring-websockets/src/main/java/com/baeldung/debugwebsockets/WebsocketApplication.java b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/WebsocketApplication.java new file mode 100644 index 0000000000..1d0d6950d3 --- /dev/null +++ b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/WebsocketApplication.java @@ -0,0 +1,13 @@ +package com.baeldung.debugwebsockets; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class WebsocketApplication { + + public static void main(String[] args) { + SpringApplication.run(WebsocketApplication.class, args); + } + +} diff --git a/spring-websockets/src/main/java/com/baeldung/debugwebsockets/WebsocketConfiguration.java b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/WebsocketConfiguration.java new file mode 100644 index 0000000000..3735e7359b --- /dev/null +++ b/spring-websockets/src/main/java/com/baeldung/debugwebsockets/WebsocketConfiguration.java @@ -0,0 +1,25 @@ +package com.baeldung.debugwebsockets; + +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.scheduling.annotation.EnableScheduling; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; +import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; + +@Configuration +@EnableWebSocketMessageBroker +@EnableScheduling +public class WebsocketConfiguration implements WebSocketMessageBrokerConfigurer { + @Override + public void configureMessageBroker(MessageBrokerRegistry config) { + config.enableSimpleBroker("/topic"); + config.setApplicationDestinationPrefixes("/app"); + } + + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry.addEndpoint("/stock-ticks").setAllowedOriginPatterns("*").withSockJS(); + } + +} diff --git a/spring-websockets/src/test/java/com/baeldung/debugwebsockets/WebSocketIntegrationTest.java b/spring-websockets/src/test/java/com/baeldung/debugwebsockets/WebSocketIntegrationTest.java new file mode 100644 index 0000000000..bdc283b9e4 --- /dev/null +++ b/spring-websockets/src/test/java/com/baeldung/debugwebsockets/WebSocketIntegrationTest.java @@ -0,0 +1,114 @@ +package com.baeldung.debugwebsockets; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.web.server.LocalServerPort; +import org.springframework.messaging.converter.MappingJackson2MessageConverter; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompFrameHandler; +import org.springframework.messaging.simp.stomp.StompHeaders; +import org.springframework.messaging.simp.stomp.StompSession; +import org.springframework.messaging.simp.stomp.StompSessionHandler; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.client.standard.StandardWebSocketClient; +import org.springframework.web.socket.messaging.WebSocketStompClient; + +import java.lang.reflect.Type; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +/** + * This should be part of integration test suite. + * The test starts the server and then connects to the WebSocket. Then verifies if the messages are received from the + * WebSocket. + * This test is inspired from: https://github.com/spring-guides/gs-messaging-stomp-websocket/blob/main/complete/src/test/java/com/example/messagingstompwebsocket/GreetingIntegrationTests.java + */ +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +class WebSocketIntegrationTest{ + WebSocketClient client; + WebSocketStompClient stompClient; + @LocalServerPort + private int port; + private static final Logger logger= LoggerFactory.getLogger(WebSocketIntegrationTest.class); + + @BeforeEach + public void setup() { + logger.info("Setting up the tests ..."); + client = new StandardWebSocketClient(); + stompClient = new WebSocketStompClient(client); + stompClient.setMessageConverter(new MappingJackson2MessageConverter()); + } + + @Test + void givenWebSocket_whenMessage_thenVerifyMessage() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference failure = new AtomicReference<>(); + StompSessionHandler sessionHandler = new StompSessionHandler() { + @Override + public Type getPayloadType(StompHeaders headers) { + return null; + } + + @Override + public void handleFrame(StompHeaders headers, Object payload) { + } + + @Override + public void afterConnected(StompSession session, StompHeaders connectedHeaders) { + logger.info("Connected to the WebSocket ..."); + session.subscribe("/topic/ticks", new StompFrameHandler() { + @Override + public Type getPayloadType(StompHeaders headers) { + return Map.class; + } + + @Override + public void handleFrame(StompHeaders headers, Object payload) { + try { + + assertThat(payload).isNotNull(); + assertThat(payload).isInstanceOf(Map.class); + + @SuppressWarnings("unchecked") + Map map = (Map) payload; + + assertThat(map).containsKey("HPE"); + assertThat(map.get("HPE")).isInstanceOf(Integer.class); + } catch (Throwable t) { + failure.set(t); + logger.error("There is an exception ", t); + } finally { + session.disconnect(); + latch.countDown(); + } + + } + }); + } + + @Override + public void handleException(StompSession session, StompCommand command, StompHeaders headers, byte[] payload, Throwable exception) { + } + + @Override + public void handleTransportError(StompSession session, Throwable exception) { + } + }; + stompClient.connect("ws://localhost:{port}/stock-ticks/websocket", sessionHandler, this.port); + if (latch.await(20, TimeUnit.SECONDS)) { + if (failure.get() != null) { + fail("Assertion Failed", failure.get()); + } + } else { + fail("Could not receive the message on time"); + } + } +}