mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-03-01 02:49:11 +00:00
Polish DefaultOAuth2AuthorizedClientManager
This commit is contained in:
parent
55f1c695e1
commit
bb8706977d
@ -15,6 +15,13 @@
|
|||||||
*/
|
*/
|
||||||
package org.springframework.security.oauth2.client.web;
|
package org.springframework.security.oauth2.client.web;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
import org.springframework.lang.Nullable;
|
import org.springframework.lang.Nullable;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
|
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
|
||||||
@ -31,13 +38,6 @@ import org.springframework.util.StringUtils;
|
|||||||
import org.springframework.web.context.request.RequestContextHolder;
|
import org.springframework.web.context.request.RequestContextHolder;
|
||||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.servlet.http.HttpServletResponse;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.function.Function;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The default implementation of an {@link OAuth2AuthorizedClientManager}.
|
* The default implementation of an {@link OAuth2AuthorizedClientManager}.
|
||||||
*
|
*
|
||||||
@ -84,13 +84,13 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
|
|||||||
if (authorizedClient != null) {
|
if (authorizedClient != null) {
|
||||||
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
|
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
|
||||||
} else {
|
} else {
|
||||||
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
|
|
||||||
Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
|
|
||||||
authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
||||||
clientRegistrationId, principal, servletRequest);
|
clientRegistrationId, principal, servletRequest);
|
||||||
if (authorizedClient != null) {
|
if (authorizedClient != null) {
|
||||||
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
|
contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
|
||||||
} else {
|
} else {
|
||||||
|
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
|
||||||
|
Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
|
||||||
contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
|
contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,18 @@
|
|||||||
*/
|
*/
|
||||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
@ -23,6 +35,8 @@ import org.mockito.ArgumentCaptor;
|
|||||||
import org.mockito.Captor;
|
import org.mockito.Captor;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
import org.mockito.junit.MockitoJUnitRunner;
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
|
import reactor.util.context.Context;
|
||||||
|
|
||||||
import org.springframework.core.codec.ByteBufferEncoder;
|
import org.springframework.core.codec.ByteBufferEncoder;
|
||||||
import org.springframework.core.codec.CharSequenceEncoder;
|
import org.springframework.core.codec.CharSequenceEncoder;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
@ -76,26 +90,26 @@ import org.springframework.web.context.request.ServletRequestAttributes;
|
|||||||
import org.springframework.web.reactive.function.BodyInserter;
|
import org.springframework.web.reactive.function.BodyInserter;
|
||||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||||
import org.springframework.web.reactive.function.client.WebClient;
|
import org.springframework.web.reactive.function.client.WebClient;
|
||||||
import reactor.util.context.Context;
|
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.servlet.http.HttpServletResponse;
|
|
||||||
import java.net.URI;
|
|
||||||
import java.time.Duration;
|
|
||||||
import java.time.Instant;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.function.Consumer;
|
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
||||||
import static org.mockito.Mockito.*;
|
import static org.mockito.Mockito.any;
|
||||||
|
import static org.mockito.Mockito.eq;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
import static org.springframework.http.HttpMethod.GET;
|
import static org.springframework.http.HttpMethod.GET;
|
||||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
|
||||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY;
|
||||||
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication;
|
||||||
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication;
|
||||||
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest;
|
||||||
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse;
|
||||||
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest;
|
||||||
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse;
|
||||||
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Rob Winch
|
* @author Rob Winch
|
||||||
@ -603,8 +617,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||||||
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
|
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
|
||||||
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
|
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
|
||||||
|
|
||||||
when(this.clientRegistrationRepository.findByRegistrationId(eq(authentication.getAuthorizedClientRegistrationId()))).thenReturn(this.registration);
|
|
||||||
|
|
||||||
// Default request attributes set
|
// Default request attributes set
|
||||||
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
|
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
|
||||||
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
|
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user