Add Servlet and ServerBearerExchangeFilterFunction

Fixes gh-5334
Fixes gh-7284
This commit is contained in:
Josh Cummings 2019-08-27 09:30:26 -06:00
parent dbd1819ea4
commit f350988285
5 changed files with 646 additions and 0 deletions

View File

@ -0,0 +1,248 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.resource.web;
import java.util.Map;
import java.util.function.Consumer;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.util.context.Context;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
/**
* An {@link ExchangeFilterFunction} that adds the
* <a href="https://tools.ietf.org/html/rfc6750#section-1.2" target="_blank">Bearer Token</a>
* from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}.
*
* Suitable for Servlet applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient}
* configuration:
*
* <pre>
* @Bean
* WebClient webClient() {
* ServletBearerExchangeFilterFunction bearer = new ServletBearerExchangeFilterFunction();
* return WebClient.builder()
* .apply(bearer.oauth2Configuration())
* .build();
* }
* </pre>
*
* @author Josh Cummings
* @since 5.2
*/
public class ServletBearerExchangeFilterFunction
implements ExchangeFilterFunction, InitializingBean, DisposableBean {
private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
/**
* {@inheritDoc}
*/
@Override
public void afterPropertiesSet() throws Exception {
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY,
Operators.liftPublisher((s, sub) -> createRequestContextSubscriber(sub)));
}
/**
* {@inheritDoc}
*/
@Override
public void destroy() throws Exception {
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
}
/**
* Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction}
* @return the {@link Consumer} to configure the builder
*/
public Consumer<WebClient.Builder> oauth2Configuration() {
return builder -> builder.defaultRequest(defaultRequest()).filter(this);
}
/**
* Provides defaults for the {@link Authentication} using
* {@link SecurityContextHolder}. It also can default the {@link AbstractOAuth2Token} using the
* {@link #authentication(Authentication)}.
* @return the {@link Consumer} to populate the attributes
*/
public Consumer<WebClient.RequestHeadersSpec<?>> defaultRequest() {
return spec -> spec.attributes(attrs -> {
populateDefaultAuthentication(attrs);
});
}
/**
* Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} used to
* look up and save the {@link AbstractOAuth2Token}. The value is defaulted in
* {@link ServletBearerExchangeFilterFunction#defaultRequest()}
*
* @param authentication the {@link Authentication} to use.
* @return the {@link Consumer} to populate the attributes
*/
public static Consumer<Map<String, Object>> authentication(Authentication authentication) {
return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication);
}
/**
* {@inheritDoc}
*/
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return mergeRequestAttributesIfNecessary(request)
.filter(req -> req.attribute(AUTHENTICATION_ATTR_NAME).isPresent())
.map(req -> getOAuth2Token(req.attributes()))
.map(token -> bearer(request, token))
.flatMap(next::exchange)
.switchIfEmpty(Mono.defer(() -> next.exchange(request)));
}
private Mono<ClientRequest> mergeRequestAttributesIfNecessary(ClientRequest request) {
if (request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) {
return Mono.just(request);
}
return mergeRequestAttributesFromContext(request);
}
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
ClientRequest.Builder builder = ClientRequest.from(request);
return Mono.subscriberContext()
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))
.map(ClientRequest.Builder::build);
}
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
if (holder == null) {
return;
}
if (holder.getAuthentication() != null) {
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, holder.getAuthentication());
}
}
private AbstractOAuth2Token getOAuth2Token(Map<String, Object> attrs) {
Authentication authentication = (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME);
if (authentication.getCredentials() instanceof AbstractOAuth2Token) {
return (AbstractOAuth2Token) authentication.getCredentials();
}
return null;
}
private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) {
return ClientRequest.from(request)
.headers(headers -> headers.setBearerAuth(token.getTokenValue()))
.build();
}
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return new RequestContextSubscriber<>(delegate, authentication);
}
private void populateDefaultAuthentication(Map<String, Object> attrs) {
if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
return;
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
}
private static class RequestContextDataHolder {
private final Authentication authentication;
RequestContextDataHolder(Authentication authentication) {
this.authentication = authentication;
}
public Authentication getAuthentication() {
return this.authentication;
}
}
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
private static final String REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME =
RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
private CoreSubscriber<T> delegate;
private final Context context;
private RequestContextSubscriber(CoreSubscriber<T> delegate,
Authentication authentication) {
this.delegate = delegate;
Context parentContext = this.delegate.currentContext();
Context context;
if (authentication == null || parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME)) {
context = parentContext;
} else {
context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME,
new RequestContextDataHolder(authentication));
}
this.context = context;
}
@Nullable
static RequestContextDataHolder getRequestContext(Context ctx) {
return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME, null);
}
@Override
public Context currentContext() {
return this.context;
}
@Override
public void onSubscribe(Subscription s) {
this.delegate.onSubscribe(s);
}
@Override
public void onNext(T t) {
this.delegate.onNext(t);
}
@Override
public void onError(Throwable t) {
this.delegate.onError(t);
}
@Override
public void onComplete() {
this.delegate.onComplete();
}
}
}

View File

@ -0,0 +1,117 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.resource.web.server;
import java.util.Map;
import java.util.function.Consumer;
import reactor.core.publisher.Mono;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
/**
* An {@link ExchangeFilterFunction} that adds the
* <a href="https://tools.ietf.org/html/rfc6750#section-1.2" target="_blank">Bearer Token</a>
* from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}.
*
* Suitable for Reactive applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient}
* configuration:
*
* <pre>
* @Bean
* WebClient webClient() {
* ServerBearerExchangeFilterFunction bearer = new ServerBearerExchangeFilterFunction();
* return WebClient.builder()
* .filter(bearer).build();
* }
* </pre>
*
* @author Josh Cummings
* @since 5.2
*/
public class ServerBearerExchangeFilterFunction
implements ExchangeFilterFunction {
private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
AuthorityUtils.createAuthorityList("ROLE_USER"));
/**
* Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} to be used for
* providing the Bearer Token. Example usage:
*
* <pre>
* WebClient webClient = WebClient.builder()
* .filter(new ServerBearerExchangeFilterFunction())
* .build();
* Mono<String> response = webClient
* .get()
* .uri(uri)
* .attributes(authentication(authentication))
* // ...
* .retrieve()
* .bodyToMono(String.class);
* </pre>
* @param authentication the {@link Authentication} to use
* @return the {@link Consumer} to populate the client request attributes
*/
public static Consumer<Map<String, Object>> authentication(Authentication authentication) {
return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication);
}
/**
* {@inheritDoc}
*/
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return oauth2Token(request.attributes())
.map(oauth2Token -> bearer(request, oauth2Token))
.defaultIfEmpty(request)
.flatMap(next::exchange);
}
private Mono<AbstractOAuth2Token> oauth2Token(Map<String, Object> attrs) {
return Mono.justOrEmpty(attrs.get(AUTHENTICATION_ATTR_NAME))
.cast(Authentication.class)
.switchIfEmpty(currentAuthentication())
.filter(authentication -> authentication.getCredentials() instanceof AbstractOAuth2Token)
.map(Authentication::getCredentials)
.cast(AbstractOAuth2Token.class);
}
private Mono<Authentication> currentAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
}
private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) {
return ClientRequest.from(request)
.headers(headers -> headers.setBearerAuth(token.getTokenValue()))
.build();
}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.resource.web;
import java.util.ArrayList;
import java.util.List;
import reactor.core.publisher.Mono;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import static org.mockito.Mockito.mock;
/**
* @author Rob Winch
* @since 5.1
*/
public class MockExchangeFunction implements ExchangeFunction {
private List<ClientRequest> requests = new ArrayList<>();
private ClientResponse response = mock(ClientResponse.class);
public ClientRequest getRequest() {
return this.requests.get(this.requests.size() - 1);
}
public List<ClientRequest> getRequests() {
return this.requests;
}
public ClientResponse getResponse() {
return this.response;
}
@Override
public Mono<ClientResponse> exchange(ClientRequest request) {
return Mono.defer(() -> {
this.requests.add(request);
return Mono.just(this.response);
});
}
}

View File

@ -0,0 +1,116 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.resource.web;
import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.Map;
import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken;
import org.springframework.web.reactive.function.client.ClientRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication;
/**
* Tests for {@link ServletBearerExchangeFilterFunction}
*
* @author Josh Cummings
*/
@RunWith(MockitoJUnitRunner.class)
public class ServletBearerExchangeFilterFunctionTests {
private ServletBearerExchangeFilterFunction function = new ServletBearerExchangeFilterFunction();
private MockExchangeFunction exchange = new MockExchangeFunction();
private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"token-0",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken<OAuth2AccessToken>(accessToken) {
@Override
public Map<String, Object> getTokenAttributes() {
return Collections.emptyMap();
}
};
@After
public void cleanup() {
SecurityContextHolder.clearContext();
}
@Test
public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build();
this.function.filter(request, this.exchange).block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
}
@Test
public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception {
this.function.afterPropertiesSet();
SecurityContextHolder.getContext().setAuthentication(this.authentication);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build();
this.function.filter(request, this.exchange).block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION))
.isEqualTo("Bearer " + this.accessToken.getTokenValue());
}
@Test
public void filterWhenAuthenticationAttributeThenAuthorizationHeader() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(authentication(this.authentication))
.build();
this.function.filter(request, this.exchange).block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION))
.isEqualTo("Bearer " + this.accessToken.getTokenValue());
}
@Test
public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.header(HttpHeaders.AUTHORIZATION, "Existing")
.attributes(authentication(this.authentication))
.build();
this.function.filter(request, this.exchange).block();
HttpHeaders headers = this.exchange.getRequest().headers();
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
}
}

View File

@ -0,0 +1,107 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.resource.web.server;
import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.Map;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.web.MockExchangeFunction;
import org.springframework.web.reactive.function.client.ClientRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication;
/**
* Tests for {@link ServerBearerExchangeFilterFunction}
*
* @author Josh Cummings
*/
public class ServerBearerExchangeFilterFunctionTests {
private ServerBearerExchangeFilterFunction function = new ServerBearerExchangeFilterFunction();
private MockExchangeFunction exchange = new MockExchangeFunction();
private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"token-0",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken<OAuth2AccessToken>(accessToken) {
@Override
public Map<String, Object> getTokenAttributes() {
return Collections.emptyMap();
}
};
@Test
public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build();
this.function.filter(request, this.exchange).block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
}
@Test
public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build();
this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication))
.block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION))
.isEqualTo("Bearer " + this.accessToken.getTokenValue());
}
@Test
public void filterWhenAuthenticationAttributeThenAuthorizationHeader() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(authentication(this.authentication))
.build();
this.function.filter(request, this.exchange).block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION))
.isEqualTo("Bearer " + this.accessToken.getTokenValue());
}
@Test
public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.header(HttpHeaders.AUTHORIZATION, "Existing")
.attributes(authentication(this.authentication))
.build();
this.function.filter(request, this.exchange).block();
HttpHeaders headers = this.exchange.getRequest().headers();
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
}
}