Add OAuth2LoginSpec

Issue: gh-4807
This commit is contained in:
Rob Winch 2018-05-11 00:40:44 -05:00
parent 23f4b9d3d1
commit 7013c6fd76
7 changed files with 303 additions and 18 deletions

View File

@ -82,7 +82,8 @@ import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@Documented
@Import({ServerHttpSecurityConfiguration.class, WebFluxSecurityConfiguration.class})
@Import({ServerHttpSecurityConfiguration.class, WebFluxSecurityConfiguration.class,
ReactiveOAuth2ClientImportSelector.class})
@Configuration
public @interface EnableWebFluxSecurity {
}

View File

@ -0,0 +1,80 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.reactive;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2ClientArgumentResolver;
import org.springframework.util.ClassUtils;
import org.springframework.web.reactive.config.WebFluxConfigurer;
import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
/**
* {@link Configuration} for OAuth 2.0 Client support.
*
* <p>
* This {@code Configuration} is imported by {@link EnableWebFluxSecurity}
*
* @author Rob Winch
* @since 5.1
*/
final class ReactiveOAuth2ClientImportSelector implements ImportSelector {
@Override
public String[] selectImports(AnnotationMetadata importingClassMetadata) {
boolean oauth2ClientPresent = ClassUtils.isPresent(
"org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader());
return oauth2ClientPresent ?
new String[] { "org.springframework.security.config.annotation.web.reactive.ReactiveOAuth2ClientImportSelector$OAuth2ClientWebFluxSecurityConfiguration" } :
new String[] {};
}
@Configuration
static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
@Override
public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
if (this.clientRegistrationRepository != null && this.authorizedClientService != null) {
configurer.addCustomResolver(new OAuth2ClientArgumentResolver(this.clientRegistrationRepository, this.authorizedClientService));
}
}
@Autowired(required = false)
public void setClientRegistrationRepository(List<ReactiveClientRegistrationRepository> clientRegistrationRepository) {
if (clientRegistrationRepository.size() == 1) {
this.clientRegistrationRepository = clientRegistrationRepository.get(0);
}
}
@Autowired(required = false)
public void setAuthorizedClientService(List<ReactiveOAuth2AuthorizedClientService> authorizedClientService) {
if (authorizedClientService.size() == 1) {
this.authorizedClientService = authorizedClientService.get(0);
}
}
}
}

View File

@ -16,8 +16,11 @@
package org.springframework.security.config.annotation.web.reactive;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Scope;
import org.springframework.context.expression.BeanFactoryResolver;
@ -31,8 +34,6 @@ import org.springframework.security.web.reactive.result.method.annotation.Authen
import org.springframework.web.reactive.config.WebFluxConfigurer;
import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
import static org.springframework.security.config.web.server.ServerHttpSecurity.http;
/**
* @author Rob Winch
* @since 5.0
@ -74,7 +75,8 @@ class ServerHttpSecurityConfiguration implements WebFluxConfigurer {
@Bean(HTTPSECURITY_BEAN_NAME)
@Scope("prototype")
public ServerHttpSecurity httpSecurity() {
return http()
ContextAwareServerHttpSecurity http = new ContextAwareServerHttpSecurity();
return http
.authenticationManager(authenticationManager())
.headers().and()
.logout().and();
@ -94,4 +96,13 @@ class ServerHttpSecurityConfiguration implements WebFluxConfigurer {
}
return null;
}
private static class ContextAwareServerHttpSecurity extends ServerHttpSecurity implements
ApplicationContextAware {
@Override
public void setApplicationContext(ApplicationContext applicationContext)
throws BeansException {
super.setApplicationContext(applicationContext);
}
}
}

View File

@ -16,6 +16,9 @@
package org.springframework.security.config.annotation.web.reactive;
import java.util.Arrays;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
@ -25,12 +28,10 @@ import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.web.reactive.result.view.AbstractView;
import java.util.Arrays;
import java.util.List;
/**
* @author Rob Winch
* @since 5.0
@ -43,6 +44,11 @@ class WebFluxSecurityConfiguration {
private static final String SPRING_SECURITY_WEBFILTERCHAINFILTER_BEAN_NAME = BEAN_NAME_PREFIX + "WebFilterChainFilter";
public static final String REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME = "org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository";
private static final boolean isOAuth2Present = ClassUtils.isPresent(
REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME, WebFluxSecurityConfiguration.class.getClassLoader());
@Autowired(required = false)
private List<SecurityWebFilterChain> securityWebFilterChains;
@ -85,6 +91,22 @@ class WebFluxSecurityConfiguration {
.and()
.httpBasic().and()
.formLogin();
return http.build();
if (isOAuth2Present) {
OAuth2ClasspathGuard.configure(this.context, http);
}
SecurityWebFilterChain result = http.build();
return result;
}
private static class OAuth2ClasspathGuard {
static void configure(ApplicationContext context, ServerHttpSecurity http) {
ClassLoader loader = context.getClassLoader();
Class<?> reactiveClientRegistrationRepositoryClass = ClassUtils.resolveClassName(REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME, loader);
if (context.getBeanNamesForType(reactiveClientRegistrationRepositoryClass).length == 1) {
http.oauth2Login();
}
}
}
}

View File

@ -24,9 +24,14 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.Ordered;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
@ -35,12 +40,23 @@ import org.springframework.security.authorization.AuthenticatedReactiveAuthoriza
import org.springframework.security.authorization.AuthorityReactiveAuthorizationManager;
import org.springframework.security.authorization.AuthorizationDecision;
import org.springframework.security.authorization.ReactiveAuthorizationManager;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginReactiveAuthenticationManager;
import org.springframework.security.oauth2.client.endpoint.NimbusReactiveAuthorizationCodeTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.ServerOAuth2LoginAuthenticationTokenConverter;
import org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint;
import org.springframework.security.web.server.MatcherSecurityWebFilterChain;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
import org.springframework.security.web.server.ServerFormLoginAuthenticationConverter;
import org.springframework.security.web.server.ServerHttpBasicAuthenticationConverter;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint;
@ -79,6 +95,7 @@ import org.springframework.security.web.server.savedrequest.WebSessionServerRequ
import org.springframework.security.web.server.ui.LoginPageGeneratingWebFilter;
import org.springframework.security.web.server.ui.LogoutPageGeneratingWebFilter;
import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcherEntry;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
@ -161,6 +178,8 @@ public class ServerHttpSecurity {
private FormLoginSpec formLogin;
private OAuth2LoginSpec oauth2Login;
private LogoutSpec logout = new LogoutSpec();
private ReactiveAuthenticationManager authenticationManager;
@ -175,6 +194,8 @@ public class ServerHttpSecurity {
private List<WebFilter> webFilters = new ArrayList<>();
private ApplicationContext context;
private Throwable built;
/**
@ -318,6 +339,90 @@ public class ServerHttpSecurity {
return this.formLogin;
}
public OAuth2LoginSpec oauth2Login() {
if (this.oauth2Login == null) {
this.oauth2Login = new OAuth2LoginSpec();
}
return this.oauth2Login;
}
public class OAuth2LoginSpec {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
public OAuth2LoginSpec clientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) {
this.clientRegistrationRepository = clientRegistrationRepository;
return this;
}
public OAuth2LoginSpec authorizedClientService(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
this.authorizedClientService = authorizedClientService;
return this;
}
protected void configure(LoginPageGeneratingWebFilter loginPageFilter, ServerHttpSecurity http) {
if (loginPageFilter != null) {
loginPageFilter.setOauth2AuthenticationUrlToClientName(getLinks());
}
ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService();
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(clientRegistrationRepository);
NimbusReactiveAuthorizationCodeTokenResponseClient client = new NimbusReactiveAuthorizationCodeTokenResponseClient();
ReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService();
OAuth2LoginReactiveAuthenticationManager manager = new OAuth2LoginReactiveAuthenticationManager(client, userService,
authorizedClientService);
AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(manager);
authenticationFilter.setRequiresAuthenticationMatcher(new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}"));
authenticationFilter.setAuthenticationConverter(new ServerOAuth2LoginAuthenticationTokenConverter(clientRegistrationRepository));
RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler();
authenticationFilter.setAuthenticationSuccessHandler(redirectHandler);
authenticationFilter.setAuthenticationFailureHandler(new ServerAuthenticationFailureHandler() {
@Override
public Mono<Void> onAuthenticationFailure(WebFilterExchange webFilterExchange,
AuthenticationException exception) {
return Mono.error(exception);
}
});
authenticationFilter.setSecurityContextRepository(new WebSessionServerSecurityContextRepository());
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
}
private Map<String, String> getLinks() {
Iterable<ClientRegistration> registrations = getBeanOrNull(ResolvableType.forClassWithGenerics(Iterable.class, ClientRegistration.class));
if (registrations == null) {
return Collections.emptyMap();
}
Map<String, String> result = new HashMap<>();
registrations.iterator().forEachRemaining(r -> {
result.put("/oauth2/authorization/" + r.getRegistrationId(), r.getClientName());
});
return result;
}
private ReactiveClientRegistrationRepository getClientRegistrationRepository() {
if (this.clientRegistrationRepository == null) {
this.clientRegistrationRepository = getBeanOrNull(ReactiveClientRegistrationRepository.class);
}
return this.clientRegistrationRepository;
}
private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() {
if (this.authorizedClientService == null) {
this.authorizedClientService = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
}
return this.authorizedClientService;
}
private OAuth2LoginSpec() {}
}
/**
* Configures HTTP Response Headers. The default headers are:
*
@ -505,17 +610,22 @@ public class ServerHttpSecurity {
this.httpBasic.authenticationManager(this.authenticationManager);
this.httpBasic.configure(this);
}
LoginPageGeneratingWebFilter loginPageFilter = null;
if(this.formLogin != null) {
this.formLogin.authenticationManager(this.authenticationManager);
if(this.securityContextRepository != null) {
this.formLogin.securityContextRepository(this.securityContextRepository);
}
if(this.formLogin.authenticationEntryPoint == null) {
this.webFilters.add(new OrderedWebFilter(new LoginPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder()));
loginPageFilter = new LoginPageGeneratingWebFilter();
this.webFilters.add(new OrderedWebFilter(loginPageFilter, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder()));
this.webFilters.add(new OrderedWebFilter(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING.getOrder()));
}
this.formLogin.configure(this);
}
if (this.oauth2Login != null) {
this.oauth2Login.configure(loginPageFilter, this);
}
if(this.logout != null) {
this.logout.configure(this);
}
@ -589,7 +699,7 @@ public class ServerHttpSecurity {
return new OrderedWebFilter(result, SecurityWebFiltersOrder.REACTOR_CONTEXT.getOrder());
}
private ServerHttpSecurity() {}
protected ServerHttpSecurity() {}
/**
* Configures authorization
@ -1402,6 +1512,27 @@ public class ServerHttpSecurity {
private LogoutSpec() {}
}
private <T> T getBeanOrNull(Class<T> beanClass) {
return getBeanOrNull(ResolvableType.forClass(beanClass));
}
private <T> T getBeanOrNull(ResolvableType type) {
if (this.context == null) {
return null;
}
String[] names = this.context.getBeanNamesForType(type);
if (names.length == 1) {
return (T) this.context.getBean(names[0]);
}
return null;
}
protected void setApplicationContext(ApplicationContext applicationContext)
throws BeansException {
this.context = applicationContext;
}
private static class OrderedWebFilter implements WebFilter, Ordered {
private final WebFilter webFilter;
private final int order;

View File

@ -8,10 +8,14 @@ dependencies {
compile 'com.nimbusds:oauth2-oidc-sdk'
optional project(':spring-security-oauth2-jose')
optional 'io.projectreactor:reactor-core'
optional 'org.springframework:spring-webflux'
testCompile powerMock2Dependencies
testCompile 'com.squareup.okhttp3:mockwebserver'
testCompile 'com.fasterxml.jackson.core:jackson-databind'
testCompile 'io.projectreactor.ipc:reactor-netty'
testCompile 'io.projectreactor:reactor-test'
provided 'javax.servlet:javax.servlet-api'
}

View File

@ -16,6 +16,10 @@
package org.springframework.security.web.server.ui;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.Map;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpMethod;
@ -25,13 +29,14 @@ import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import org.springframework.web.util.HtmlUtils;
import java.nio.charset.Charset;
import reactor.core.publisher.Mono;
/**
* Generates a default log in page used for authenticating users.
@ -43,6 +48,14 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
private ServerWebExchangeMatcher matcher = ServerWebExchangeMatchers
.pathMatchers(HttpMethod.GET, "/login");
private Map<String, String> oauth2AuthenticationUrlToClientName = new HashMap<>();
public void setOauth2AuthenticationUrlToClientName(
Map<String, String> oauth2AuthenticationUrlToClientName) {
Assert.notNull(oauth2AuthenticationUrlToClientName, "oauth2AuthenticationUrlToClientName cannot be null");
this.oauth2AuthenticationUrlToClientName = oauth2AuthenticationUrlToClientName;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return this.matcher.matches(exchange)
@ -59,22 +72,24 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
}
private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest()
.getQueryParams();
Mono<CsrfToken> token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty());
return token
.map(LoginPageGeneratingWebFilter::csrfToken)
.defaultIfEmpty("")
.map(csrfTokenHtmlInput -> {
boolean isError = queryParams.containsKey("error");
boolean isLogoutSuccess = queryParams.containsKey("logout");
byte[] bytes = createPage(isError, isLogoutSuccess, csrfTokenHtmlInput);
byte[] bytes = createPage(exchange, csrfTokenHtmlInput);
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
return bufferFactory.wrap(bytes);
});
}
private static byte[] createPage(boolean isError, boolean isLogoutSuccess, String csrfTokenHtmlInput) {
private byte[] createPage(ServerWebExchange exchange, String csrfTokenHtmlInput) {
MultiValueMap<String, String> queryParams = exchange.getRequest()
.getQueryParams();
boolean isError = queryParams.containsKey("error");
boolean isLogoutSuccess = queryParams.containsKey("logout");
String contextPath = exchange.getRequest().getPath().contextPath().value();
String page = "<!DOCTYPE html>\n"
+ "<html lang=\"en\">\n"
+ " <head>\n"
@ -103,6 +118,7 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
+ csrfTokenHtmlInput
+ " <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
+ " </form>\n"
+ oauth2LoginLinks(contextPath, this.oauth2AuthenticationUrlToClientName)
+ " </div>\n"
+ " </body>\n"
+ "</html>";
@ -110,6 +126,26 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
return page.getBytes(Charset.defaultCharset());
}
private static String oauth2LoginLinks(String contextPath, Map<String, String> oauth2AuthenticationUrlToClientName) {
if (oauth2AuthenticationUrlToClientName.isEmpty()) {
return "";
}
StringBuilder sb = new StringBuilder();
sb.append("<div class=\"container\"><h2 class=\"form-signin-heading\">Login with OAuth 2.0</h3>");
sb.append("<table class=\"table table-striped\">\n");
for (Map.Entry<String, String> clientAuthenticationUrlToClientName : oauth2AuthenticationUrlToClientName.entrySet()) {
sb.append(" <tr><td>");
String url = clientAuthenticationUrlToClientName.getKey();
sb.append("<a href=\"").append(contextPath).append(url).append("\">");
String clientName = HtmlUtils.htmlEscape(clientAuthenticationUrlToClientName.getValue());
sb.append(clientName);
sb.append("</a>");
sb.append("</td></tr>\n");
}
sb.append("</table></div>\n");
return sb.toString();
}
private static String csrfToken(CsrfToken token) {
return " <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n";
}