diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java index 1b5067c576..940d685eaf 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,7 +73,8 @@ import org.springframework.security.config.annotation.web.WebSecurityConfigurer; @Target(value = { java.lang.annotation.ElementType.TYPE }) @Documented @Import({ WebSecurityConfiguration.class, - SpringWebMvcImportSelector.class }) + SpringWebMvcImportSelector.class, + OAuth2ImportSelector.class }) @EnableGlobalAuthentication @Configuration public @interface EnableWebSecurity { @@ -83,4 +84,4 @@ public @interface EnableWebSecurity { * @return if true, enables debug support with Spring Security */ boolean debug() default false; -} \ No newline at end of file +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java new file mode 100644 index 0000000000..e0ff29b599 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.web.configuration; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.context.annotation.ImportSelector; +import org.springframework.core.type.AnnotationMetadata; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.method.annotation.OAuth2ClientArgumentResolver; +import org.springframework.util.ClassUtils; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import java.util.List; + +/** + * {@link Configuration} for OAuth 2.0 Client support. + * + *
+ * This {@code Configuration} is conditionally imported by {@link OAuth2ImportSelector}
+ * when the {@code spring-security-oauth2-client} module is present on the classpath.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2ImportSelector
+ */
+@Import(OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class)
+final class OAuth2ClientConfiguration {
+
+ static class OAuth2ClientWebMvcImportSelector implements ImportSelector {
+
+ @Override
+ public String[] selectImports(AnnotationMetadata importingClassMetadata) {
+ boolean webmvcPresent = ClassUtils.isPresent(
+ "org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader());
+
+ return webmvcPresent ?
+ new String[] { "org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" } :
+ new String[] {};
+ }
+ }
+
+ @Configuration
+ static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer {
+ @Autowired(required = false)
+ private ClientRegistrationRepository clientRegistrationRepository;
+
+ @Autowired(required = false)
+ private OAuth2AuthorizedClientService authorizedClientService;
+
+ @Override
+ public void addArgumentResolvers(List
+ * For example:
+ *
+ * For example:
+ *
+ * @Controller
+ * public class MyController {
+ * @GetMapping("/client-registration")
+ * public String clientRegistration(@OAuth2Client("login-client") ClientRegistration clientRegistration) {
+ * // do something with clientRegistration
+ * }
+ *
+ * @GetMapping("/authorized-client")
+ * public String authorizedClient(@OAuth2Client("login-client") OAuth2AuthorizedClient authorizedClient) {
+ * // do something with authorizedClient
+ * }
+ *
+ * @GetMapping("/access-token")
+ * public String accessToken(@OAuth2Client("login-client") OAuth2AccessToken accessToken) {
+ * // do something with accessToken
+ * }
+ * }
+ *
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2ClientArgumentResolver
+ */
+@Target({ ElementType.PARAMETER, ElementType.ANNOTATION_TYPE })
+@Retention(RetentionPolicy.RUNTIME)
+@Documented
+public @interface OAuth2Client {
+
+ /**
+ * Sets the client registration identifier.
+ *
+ * @return the client registration identifier
+ */
+ @AliasFor("value")
+ String registrationId() default "";
+
+ /**
+ * The default attribute for this annotation.
+ * This is an alias for {@link #registrationId()}.
+ * For example, {@code @OAuth2Client("login-client")} is equivalent to
+ * {@code @OAuth2Client(registrationId="login-client")}.
+ *
+ * @return the client registration identifier
+ */
+ @AliasFor("registrationId")
+ String value() default "";
+
+}
diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java
index 89785c76d6..78ba725505 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java
@@ -93,8 +93,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private AuthorizationRequestRepository
+ * @Controller
+ * public class MyController {
+ * @GetMapping("/client-registration")
+ * public String clientRegistration(@OAuth2Client("login-client") ClientRegistration clientRegistration) {
+ * // do something with clientRegistration
+ * }
+ *
+ * @GetMapping("/authorized-client")
+ * public String authorizedClient(@OAuth2Client("login-client") OAuth2AuthorizedClient authorizedClient) {
+ * // do something with authorizedClient
+ * }
+ *
+ * @GetMapping("/access-token")
+ * public String accessToken(@OAuth2Client("login-client") OAuth2AccessToken accessToken) {
+ * // do something with accessToken
+ * }
+ * }
+ *
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2Client
+ */
+public final class OAuth2ClientArgumentResolver implements HandlerMethodArgumentResolver {
+ private final ClientRegistrationRepository clientRegistrationRepository;
+ private final OAuth2AuthorizedClientService authorizedClientService;
+
+ /**
+ * Constructs an {@code OAuth2ClientArgumentResolver} using the provided parameters.
+ *
+ * @param clientRegistrationRepository the repository of client registrations
+ * @param authorizedClientService the authorized client service
+ */
+ public OAuth2ClientArgumentResolver(ClientRegistrationRepository clientRegistrationRepository,
+ OAuth2AuthorizedClientService authorizedClientService) {
+ Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+ Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+ this.clientRegistrationRepository = clientRegistrationRepository;
+ this.authorizedClientService = authorizedClientService;
+ }
+
+ @Override
+ public boolean supportsParameter(MethodParameter parameter) {
+ Class> parameterType = parameter.getParameterType();
+ return ((OAuth2AccessToken.class.isAssignableFrom(parameterType) ||
+ OAuth2AuthorizedClient.class.isAssignableFrom(parameterType) ||
+ ClientRegistration.class.isAssignableFrom(parameterType)) &&
+ (parameter.hasParameterAnnotation(OAuth2Client.class)));
+ }
+
+ @NonNull
+ @Override
+ public Object resolveArgument(MethodParameter parameter,
+ @Nullable ModelAndViewContainer mavContainer,
+ NativeWebRequest webRequest,
+ @Nullable WebDataBinderFactory binderFactory) throws Exception {
+
+ OAuth2Client oauth2ClientAnnotation = parameter.getParameterAnnotation(OAuth2Client.class);
+ Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+
+ String clientRegistrationId = null;
+ if (!StringUtils.isEmpty(oauth2ClientAnnotation.registrationId())) {
+ clientRegistrationId = oauth2ClientAnnotation.registrationId();
+ } else if (!StringUtils.isEmpty(oauth2ClientAnnotation.value())) {
+ clientRegistrationId = oauth2ClientAnnotation.value();
+ } else if (principal != null && OAuth2AuthenticationToken.class.isAssignableFrom(principal.getClass())) {
+ clientRegistrationId = ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId();
+ }
+ if (StringUtils.isEmpty(clientRegistrationId)) {
+ throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. " +
+ "It must be provided via @OAuth2Client(\"client1\") or @OAuth2Client(registrationId = \"client1\").");
+ }
+
+ if (ClientRegistration.class.isAssignableFrom(parameter.getParameterType())) {
+ ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
+ if (clientRegistration == null) {
+ throw new IllegalArgumentException("Unable to find ClientRegistration with registration identifier \"" +
+ clientRegistrationId + "\".");
+ }
+ return clientRegistration;
+ }
+
+ if (principal == null) {
+ // An Authentication is required given that an OAuth2AuthorizedClient is associated to a Principal
+ throw new IllegalStateException("Unable to resolve the Authorized Client with registration identifier \"" +
+ clientRegistrationId + "\". An \"authenticated\" or \"unauthenticated\" session is required. " +
+ "To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured.");
+ }
+
+ OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+ clientRegistrationId, principal.getName());
+ if (authorizedClient == null) {
+ throw new ClientAuthorizationRequiredException(clientRegistrationId);
+ }
+
+ return OAuth2AccessToken.class.isAssignableFrom(parameter.getParameterType()) ?
+ authorizedClient.getAccessToken() : authorizedClient;
+ }
+}
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java
index cf9fdf95fa..3412c852d1 100644
--- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java
@@ -28,15 +28,13 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.web.savedrequest.SavedRequest;
+import org.springframework.security.web.savedrequest.RequestCache;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
-import javax.servlet.http.HttpSession;
-import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -53,6 +51,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
private ClientRegistration registration3;
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizationRequestRedirectFilter filter;
+ private RequestCache requestCache;
@Before
public void setUp() {
@@ -95,6 +94,8 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
this.registration1, this.registration2, this.registration3);
this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository);
+ this.requestCache = mock(RequestCache.class);
+ this.filter.setRequestCache(this.requestCache);
}
@Test
@@ -115,6 +116,12 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
.isInstanceOf(IllegalArgumentException.class);
}
+ @Test
+ public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> this.filter.setRequestCache(null))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
@Test
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
String requestUri = "/path";
@@ -129,7 +136,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
}
@Test
- public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusBadRequest() throws Exception {
+ public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalServerError() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration1.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@@ -141,8 +148,8 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
verifyZeroInteractions(filterChain);
- assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
- assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
+ assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}
@Test
@@ -320,16 +327,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/code/registration-1");
- HttpSession session = request.getSession(false);
- assertThat(session).isNotNull();
- boolean requestSaved = false;
- for (String attrName : Collections.list(session.getAttributeNames())) {
- if (SavedRequest.class.isAssignableFrom(session.getAttribute(attrName).getClass())) {
- requestSaved = true;
- break;
- }
- }
- assertThat(requestSaved).isTrue();
+ verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2ClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2ClientArgumentResolverTests.java
new file mode 100644
index 0000000000..5ad0fed157
--- /dev/null
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2ClientArgumentResolverTests.java
@@ -0,0 +1,260 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web.method.annotation;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.core.MethodParameter;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.annotation.OAuth2Client;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.util.ReflectionUtils;
+
+import java.lang.reflect.Method;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
+import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OAuth2ClientArgumentResolver}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2ClientArgumentResolverTests {
+ private ClientRegistrationRepository clientRegistrationRepository;
+ private OAuth2AuthorizedClientService authorizedClientService;
+ private OAuth2ClientArgumentResolver argumentResolver;
+ private ClientRegistration clientRegistration;
+ private OAuth2AuthorizedClient authorizedClient;
+ private OAuth2AccessToken accessToken;
+
+ @Before
+ public void setUp() {
+ this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+ this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
+ this.argumentResolver = new OAuth2ClientArgumentResolver(
+ this.clientRegistrationRepository, this.authorizedClientService);
+ this.clientRegistration = ClientRegistration.withRegistrationId("client1")
+ .clientId("client-id")
+ .clientSecret("secret")
+ .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+ .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+ .redirectUriTemplate("{baseUrl}/client1")
+ .scope("scope1", "scope2")
+ .authorizationUri("https://provider.com/oauth2/auth")
+ .tokenUri("https://provider.com/oauth2/token")
+ .clientName("Client 1")
+ .build();
+ when(this.clientRegistrationRepository.findByRegistrationId(anyString())).thenReturn(this.clientRegistration);
+ this.authorizedClient = mock(OAuth2AuthorizedClient.class);
+ when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(this.authorizedClient);
+ this.accessToken = mock(OAuth2AccessToken.class);
+ when(this.authorizedClient.getAccessToken()).thenReturn(this.accessToken);
+ SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+ securityContext.setAuthentication(mock(Authentication.class));
+ SecurityContextHolder.setContext(securityContext);
+ }
+
+ @Test
+ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2ClientArgumentResolver(null, this.authorizedClientService))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2ClientArgumentResolver(this.clientRegistrationRepository, null))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AccessTokenThenTrue() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AccessTokenWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessTokenWithoutAnnotation", OAuth2AccessToken.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClientWithoutAnnotation", OAuth2AuthorizedClient.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeClientRegistrationThenTrue() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeClientRegistrationWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistrationWithoutAnnotation", ClientRegistration.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeUnsupportedThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupported", String.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeUnsupportedWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupportedWithoutAnnotation", String.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() throws Exception {
+ MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AccessToken.class);
+ assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @OAuth2Client(\"client1\") or @OAuth2Client(registrationId = \"client1\").");
+ }
+
+ @Test
+ public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() throws Exception {
+ OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
+ when(authentication.getAuthorizedClientRegistrationId()).thenReturn("client1");
+ SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+ securityContext.setAuthentication(authentication);
+ SecurityContextHolder.setContext(securityContext);
+ MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AccessToken.class);
+ this.argumentResolver.resolveArgument(methodParameter, null, null, null);
+ }
+
+ @Test
+ public void resolveArgumentWhenClientRegistrationFoundThenResolves() throws Exception {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class);
+ assertThat(this.argumentResolver.resolveArgument(methodParameter, null, null, null)).isSameAs(this.clientRegistration);
+ }
+
+ @Test
+ public void resolveArgumentWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() throws Exception {
+ when(this.clientRegistrationRepository.findByRegistrationId(anyString())).thenReturn(null);
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class);
+ assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Unable to find ClientRegistration with registration identifier \"client1\".");
+ }
+
+ @Test
+ public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenThrowIllegalStateException() throws Exception {
+ SecurityContextHolder.clearContext();
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("Unable to resolve the Authorized Client with registration identifier \"client1\". " +
+ "An \"authenticated\" or \"unauthenticated\" session is required. " +
+ "To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured.");
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() throws Exception {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThat(this.argumentResolver.resolveArgument(methodParameter, null, null, null)).isSameAs(this.authorizedClient);
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception {
+ when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(null);
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
+ .isInstanceOf(ClientAuthorizationRequiredException.class);
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AccessTokenAndOAuth2AuthorizedClientFoundThenResolves() throws Exception {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class);
+ assertThat(this.argumentResolver.resolveArgument(methodParameter, null, null, null)).isSameAs(this.authorizedClient.getAccessToken());
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AccessTokenAndOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception {
+ when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(null);
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class);
+ assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
+ .isInstanceOf(ClientAuthorizationRequiredException.class);
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AccessTokenAndAnnotationRegistrationIdSetThenResolves() throws Exception {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessTokenAnnotationRegistrationId", OAuth2AccessToken.class);
+ assertThat(this.argumentResolver.resolveArgument(methodParameter, null, null, null)).isSameAs(this.authorizedClient.getAccessToken());
+ }
+
+ private MethodParameter getMethodParameter(String methodName, Class>... paramTypes) {
+ Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
+ return new MethodParameter(method, 0);
+ }
+
+ static class TestController {
+ void paramTypeAccessToken(@OAuth2Client("client1") OAuth2AccessToken accessToken) {
+ }
+
+ void paramTypeAccessTokenWithoutAnnotation(OAuth2AccessToken accessToken) {
+ }
+
+ void paramTypeAuthorizedClient(@OAuth2Client("client1") OAuth2AuthorizedClient authorizedClient) {
+ }
+
+ void paramTypeAuthorizedClientWithoutAnnotation(OAuth2AuthorizedClient authorizedClient) {
+ }
+
+ void paramTypeClientRegistration(@OAuth2Client("client1") ClientRegistration clientRegistration) {
+ }
+
+ void paramTypeClientRegistrationWithoutAnnotation(ClientRegistration clientRegistration) {
+ }
+
+ void paramTypeUnsupported(@OAuth2Client("client1") String param) {
+ }
+
+ void paramTypeUnsupportedWithoutAnnotation(String param) {
+ }
+
+ void registrationIdEmpty(@OAuth2Client OAuth2AccessToken accessToken) {
+ }
+
+ void paramTypeAccessTokenAnnotationRegistrationId(@OAuth2Client(registrationId = "client1") OAuth2AccessToken accessToken) {
+ }
+ }
+}
diff --git a/samples/boot/oauth2/authcodegrant/src/main/java/sample/web/MainController.java b/samples/boot/oauth2/authcodegrant/src/main/java/sample/web/GitHubReposController.java
similarity index 73%
rename from samples/boot/oauth2/authcodegrant/src/main/java/sample/web/MainController.java
rename to samples/boot/oauth2/authcodegrant/src/main/java/sample/web/GitHubReposController.java
index f0e525a270..f32bdfe450 100644
--- a/samples/boot/oauth2/authcodegrant/src/main/java/sample/web/MainController.java
+++ b/samples/boot/oauth2/authcodegrant/src/main/java/sample/web/GitHubReposController.java
@@ -15,12 +15,9 @@
*/
package sample.web;
-import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
-import org.springframework.security.core.Authentication;
-import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.annotation.OAuth2Client;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
@@ -35,9 +32,7 @@ import java.util.List;
* @author Joe Grandja
*/
@Controller
-public class MainController {
- @Autowired
- private OAuth2AuthorizedClientService authorizedClientService;
+public class GitHubReposController {
@GetMapping("/")
public String index() {
@@ -45,16 +40,7 @@ public class MainController {
}
@GetMapping("/repos")
- public String gitHubRepos(Model model, Authentication authentication) {
- String registrationId = "github";
-
- OAuth2AuthorizedClient authorizedClient =
- this.authorizedClientService.loadAuthorizedClient(
- registrationId, authentication.getName());
- if (authorizedClient == null) {
- throw new ClientAuthorizationRequiredException(registrationId);
- }
-
+ public String gitHubRepos(Model model, @OAuth2Client("github") OAuth2AuthorizedClient authorizedClient) {
String endpointUri = "https://api.github.com/user/repos";
List repos = WebClient.builder()
.filter(oauth2Credentials(authorizedClient))
diff --git a/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java b/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java
index 90d96e3fd1..1c722da290 100644
--- a/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java
+++ b/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java
@@ -145,7 +145,7 @@ public class OAuth2LoginApplicationTests {
}
@Test
- public void requestAuthorizeClientWhenInvalidClientThenStatusBadRequest() throws Exception {
+ public void requestAuthorizeClientWhenInvalidClientThenStatusInternalServerError() throws Exception {
HtmlPage page = this.webClient.getPage("/");
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google");
@@ -161,7 +161,7 @@ public class OAuth2LoginApplicationTests {
response = ex.getResponse();
}
- assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+ assertThat(response.getStatusCode()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
}
@Test
diff --git a/samples/boot/oauth2login/src/main/java/sample/web/MainController.java b/samples/boot/oauth2login/src/main/java/sample/web/OAuth2LoginController.java
similarity index 66%
rename from samples/boot/oauth2login/src/main/java/sample/web/MainController.java
rename to samples/boot/oauth2login/src/main/java/sample/web/OAuth2LoginController.java
index 45b93d678b..4f8fa14ca8 100644
--- a/samples/boot/oauth2login/src/main/java/sample/web/MainController.java
+++ b/samples/boot/oauth2login/src/main/java/sample/web/OAuth2LoginController.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,15 +15,13 @@
*/
package sample.web;
-import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
-import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.annotation.OAuth2Client;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.util.StringUtils;
-import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.WebClient;
@@ -36,22 +34,17 @@ import java.util.Map;
* @author Joe Grandja
*/
@Controller
-public class MainController {
+public class OAuth2LoginController {
- @Autowired
- private OAuth2AuthorizedClientService authorizedClientService;
-
- @RequestMapping("/")
- public String index(Model model, OAuth2AuthenticationToken authentication) {
- OAuth2AuthorizedClient authorizedClient = this.getAuthorizedClient(authentication);
- model.addAttribute("userName", authentication.getName());
+ @GetMapping("/")
+ public String index(Model model, @OAuth2Client OAuth2AuthorizedClient authorizedClient) {
+ model.addAttribute("userName", authorizedClient.getPrincipalName());
model.addAttribute("clientName", authorizedClient.getClientRegistration().getClientName());
return "index";
}
- @RequestMapping("/userinfo")
- public String userinfo(Model model, OAuth2AuthenticationToken authentication) {
- OAuth2AuthorizedClient authorizedClient = this.getAuthorizedClient(authentication);
+ @GetMapping("/userinfo")
+ public String userinfo(Model model, @OAuth2Client OAuth2AuthorizedClient authorizedClient) {
Map userAttributes = Collections.emptyMap();
String userInfoEndpointUri = authorizedClient.getClientRegistration()
.getProviderDetails().getUserInfoEndpoint().getUri();
@@ -69,11 +62,6 @@ public class MainController {
return "userinfo";
}
- private OAuth2AuthorizedClient getAuthorizedClient(OAuth2AuthenticationToken authentication) {
- return this.authorizedClientService.loadAuthorizedClient(
- authentication.getAuthorizedClientRegistrationId(), authentication.getName());
- }
-
private ExchangeFilterFunction oauth2Credentials(OAuth2AuthorizedClient authorizedClient) {
return ExchangeFilterFunction.ofRequestProcessor(
clientRequest -> {