From 1695d03b72d9d48f24abd8381f5f120b99a7f7a1 Mon Sep 17 00:00:00 2001 From: JANG Date: Sun, 28 Apr 2024 01:33:15 +0900 Subject: [PATCH] Assert WebSession is not null Issue gh-14975 --- ...erverOAuth2AuthorizedClientRepository.java | 11 ++++----- ...OAuth2AuthorizedClientRepositoryTests.java | 24 ++++++++++++++++++- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java index e3b462ebac..344589404a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -96,12 +96,9 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository implements S @SuppressWarnings("unchecked") private Map getAuthorizedClients(WebSession session) { - Map authorizedClients = (session != null) - ? (Map) session.getAttribute(this.sessionAttributeName) : null; - if (authorizedClients == null) { - authorizedClients = new HashMap<>(); - } - return authorizedClients; + Assert.notNull(session, "session cannot be null"); + Map authorizedClients = session.getAttribute(this.sessionAttributeName); + return (authorizedClients != null) ? authorizedClients : new HashMap<>(); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java index 55f78d8f04..4848be2b20 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,10 +25,12 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -201,5 +203,25 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { assertThat(loadedAuthorizedClient2).isNotNull(); assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2); } + + @Test + public void saveAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() { + MockServerWebExchange mockedExchange = mock(MockServerWebExchange.class); + when(mockedExchange.getSession()).thenReturn(Mono.empty()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); + assertThatIllegalArgumentException().isThrownBy( + () -> authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, mockedExchange).block()) + .withMessage("session cannot be null"); + } + + @Test + public void removeAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() { + MockServerWebExchange mockedExchange = mock(MockServerWebExchange.class); + when(mockedExchange.getSession()).thenReturn(Mono.empty()); + assertThatIllegalArgumentException().isThrownBy( + () -> authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, mockedExchange).block()) + .withMessage("session cannot be null"); + } }