mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-02-28 18:39:06 +00:00
Polish DefaultOAuth2AuthorizedClientManager
This commit is contained in:
parent
55f1c695e1
commit
bb8706977d
@ -15,6 +15,13 @@
|
||||
*/
|
||||
package org.springframework.security.oauth2.client.web;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
|
||||
@ -31,13 +38,6 @@ import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* The default implementation of an {@link OAuth2AuthorizedClientManager}.
|
||||
*
|
||||
@ -84,13 +84,13 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
|
||||
if (authorizedClient != null) {
|
||||
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
|
||||
} else {
|
||||
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
|
||||
Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
|
||||
authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
||||
clientRegistrationId, principal, servletRequest);
|
||||
if (authorizedClient != null) {
|
||||
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
|
||||
} else {
|
||||
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
|
||||
Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
|
||||
contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
|
||||
}
|
||||
}
|
||||
|
@ -15,6 +15,18 @@
|
||||
*/
|
||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Consumer;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
@ -23,6 +35,8 @@ import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Captor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import org.springframework.core.codec.ByteBufferEncoder;
|
||||
import org.springframework.core.codec.CharSequenceEncoder;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@ -76,26 +90,26 @@ import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import org.springframework.web.reactive.function.BodyInserter;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.any;
|
||||
import static org.mockito.Mockito.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.http.HttpMethod.GET;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
|
||||
|
||||
/**
|
||||
* @author Rob Winch
|
||||
@ -603,8 +617,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
|
||||
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
|
||||
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(eq(authentication.getAuthorizedClientRegistrationId()))).thenReturn(this.registration);
|
||||
|
||||
// Default request attributes set
|
||||
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
|
||||
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
|
||||
|
Loading…
x
Reference in New Issue
Block a user