diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java index e3b462ebac..b37d325ea3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -94,14 +94,10 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository implements S // @formatter:on } - @SuppressWarnings("unchecked") private Map getAuthorizedClients(WebSession session) { - Map authorizedClients = (session != null) - ? (Map) session.getAttribute(this.sessionAttributeName) : null; - if (authorizedClients == null) { - authorizedClients = new HashMap<>(); - } - return authorizedClients; + Assert.notNull(session, "session cannot be null"); + Map authorizedClients = session.getAttribute(this.sessionAttributeName); + return (authorizedClients != null) ? authorizedClients : new HashMap<>(); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java index 55f78d8f04..91c1d4d7a3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web.server; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; @@ -24,10 +25,12 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; /** @@ -202,4 +205,28 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2); } + @Test + public void saveAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() { + ServerWebExchange exchange = mock(ServerWebExchange.class); + given(exchange.getSession()).willReturn(Mono.empty()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, exchange).block()) + .withMessage("session cannot be null"); + // @formatter:on + } + + @Test + public void removeAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() { + ServerWebExchange exchange = mock(ServerWebExchange.class); + given(exchange.getSession()).willReturn(Mono.empty()); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, exchange).block()) + .withMessage("session cannot be null"); + // @formatter:on + } + }