From f35098828506d78c8899d884ba778d184ddb0e4c Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 27 Aug 2019 09:30:26 -0600 Subject: [PATCH] Add Servlet and ServerBearerExchangeFilterFunction Fixes gh-5334 Fixes gh-7284 --- .../ServletBearerExchangeFilterFunction.java | 248 ++++++++++++++++++ .../ServerBearerExchangeFilterFunction.java | 117 +++++++++ .../resource/web/MockExchangeFunction.java | 58 ++++ ...vletBearerExchangeFilterFunctionTests.java | 116 ++++++++ ...rverBearerExchangeFilterFunctionTests.java | 107 ++++++++ 5 files changed, 646 insertions(+) create mode 100644 oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunction.java create mode 100644 oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunction.java create mode 100644 oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java create mode 100644 oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunctionTests.java create mode 100644 oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunctionTests.java diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunction.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunction.java new file mode 100644 index 0000000000..820c05ac48 --- /dev/null +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunction.java @@ -0,0 +1,248 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.web; + +import java.util.Map; +import java.util.function.Consumer; + +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * An {@link ExchangeFilterFunction} that adds the + * Bearer Token + * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}. + * + * Suitable for Servlet applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient} + * configuration: + * + *
+ *  @Bean
+ *  WebClient webClient() {
+ *      ServletBearerExchangeFilterFunction bearer = new ServletBearerExchangeFilterFunction();
+ *      return WebClient.builder()
+ *              .apply(bearer.oauth2Configuration())
+ *              .build();
+ *  }
+ * 
+ * + * @author Josh Cummings + * @since 5.2 + */ +public class ServletBearerExchangeFilterFunction + implements ExchangeFilterFunction, InitializingBean, DisposableBean { + + private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); + + private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); + + /** + * {@inheritDoc} + */ + @Override + public void afterPropertiesSet() throws Exception { + Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, + Operators.liftPublisher((s, sub) -> createRequestContextSubscriber(sub))); + } + + /** + * {@inheritDoc} + */ + @Override + public void destroy() throws Exception { + Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY); + } + + /** + * Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction} + * @return the {@link Consumer} to configure the builder + */ + public Consumer oauth2Configuration() { + return builder -> builder.defaultRequest(defaultRequest()).filter(this); + } + + /** + * Provides defaults for the {@link Authentication} using + * {@link SecurityContextHolder}. It also can default the {@link AbstractOAuth2Token} using the + * {@link #authentication(Authentication)}. + * @return the {@link Consumer} to populate the attributes + */ + public Consumer> defaultRequest() { + return spec -> spec.attributes(attrs -> { + populateDefaultAuthentication(attrs); + }); + } + + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} used to + * look up and save the {@link AbstractOAuth2Token}. The value is defaulted in + * {@link ServletBearerExchangeFilterFunction#defaultRequest()} + * + * @param authentication the {@link Authentication} to use. + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> authentication(Authentication authentication) { + return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); + } + + /** + * {@inheritDoc} + */ + @Override + public Mono filter(ClientRequest request, ExchangeFunction next) { + return mergeRequestAttributesIfNecessary(request) + .filter(req -> req.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) + .map(req -> getOAuth2Token(req.attributes())) + .map(token -> bearer(request, token)) + .flatMap(next::exchange) + .switchIfEmpty(Mono.defer(() -> next.exchange(request))); + } + + private Mono mergeRequestAttributesIfNecessary(ClientRequest request) { + if (request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { + return Mono.just(request); + } + + return mergeRequestAttributesFromContext(request); + } + + private Mono mergeRequestAttributesFromContext(ClientRequest request) { + ClientRequest.Builder builder = ClientRequest.from(request); + return Mono.subscriberContext() + .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))) + .map(ClientRequest.Builder::build); + } + + private void populateRequestAttributes(Map attrs, Context ctx) { + RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx); + if (holder == null) { + return; + } + if (holder.getAuthentication() != null) { + attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, holder.getAuthentication()); + } + } + + private AbstractOAuth2Token getOAuth2Token(Map attrs) { + Authentication authentication = (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME); + if (authentication.getCredentials() instanceof AbstractOAuth2Token) { + return (AbstractOAuth2Token) authentication.getCredentials(); + } + return null; + } + + private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { + return ClientRequest.from(request) + .headers(headers -> headers.setBearerAuth(token.getTokenValue())) + .build(); + } + + private CoreSubscriber createRequestContextSubscriber(CoreSubscriber delegate) { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + return new RequestContextSubscriber<>(delegate, authentication); + } + + private void populateDefaultAuthentication(Map attrs) { + if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) { + return; + } + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); + } + + private static class RequestContextDataHolder { + private final Authentication authentication; + + RequestContextDataHolder(Authentication authentication) { + this.authentication = authentication; + } + + public Authentication getAuthentication() { + return this.authentication; + } + } + + private static class RequestContextSubscriber implements CoreSubscriber { + private static final String REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME = + RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER"); + + private CoreSubscriber delegate; + private final Context context; + + private RequestContextSubscriber(CoreSubscriber delegate, + Authentication authentication) { + + this.delegate = delegate; + Context parentContext = this.delegate.currentContext(); + Context context; + if (authentication == null || parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME)) { + context = parentContext; + } else { + context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME, + new RequestContextDataHolder(authentication)); + } + + this.context = context; + } + + @Nullable + static RequestContextDataHolder getRequestContext(Context ctx) { + return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME, null); + } + + @Override + public Context currentContext() { + return this.context; + } + + @Override + public void onSubscribe(Subscription s) { + this.delegate.onSubscribe(s); + } + + @Override + public void onNext(T t) { + this.delegate.onNext(t); + } + + @Override + public void onError(Throwable t) { + this.delegate.onError(t); + } + + @Override + public void onComplete() { + this.delegate.onComplete(); + } + } +} diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunction.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunction.java new file mode 100644 index 0000000000..0cb37fa85c --- /dev/null +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunction.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.web.server; + +import java.util.Map; +import java.util.function.Consumer; + +import reactor.core.publisher.Mono; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; + +/** + * An {@link ExchangeFilterFunction} that adds the + * Bearer Token + * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}. + * + * Suitable for Reactive applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient} + * configuration: + * + *
+ *  @Bean
+ *  WebClient webClient() {
+ *      ServerBearerExchangeFilterFunction bearer = new ServerBearerExchangeFilterFunction();
+ *      return WebClient.builder()
+ *              .filter(bearer).build();
+ *  }
+ * 
+ * + * @author Josh Cummings + * @since 5.2 + */ +public class ServerBearerExchangeFilterFunction + implements ExchangeFilterFunction { + + private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); + + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_USER")); + + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} to be used for + * providing the Bearer Token. Example usage: + * + *
+	 * WebClient webClient = WebClient.builder()
+	 *    .filter(new ServerBearerExchangeFilterFunction())
+	 *    .build();
+	 * Mono response = webClient
+	 *    .get()
+	 *    .uri(uri)
+	 *    .attributes(authentication(authentication))
+	 *    // ...
+	 *    .retrieve()
+	 *    .bodyToMono(String.class);
+	 * 
+ * @param authentication the {@link Authentication} to use + * @return the {@link Consumer} to populate the client request attributes + */ + public static Consumer> authentication(Authentication authentication) { + return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); + } + + /** + * {@inheritDoc} + */ + @Override + public Mono filter(ClientRequest request, ExchangeFunction next) { + return oauth2Token(request.attributes()) + .map(oauth2Token -> bearer(request, oauth2Token)) + .defaultIfEmpty(request) + .flatMap(next::exchange); + } + + private Mono oauth2Token(Map attrs) { + return Mono.justOrEmpty(attrs.get(AUTHENTICATION_ATTR_NAME)) + .cast(Authentication.class) + .switchIfEmpty(currentAuthentication()) + .filter(authentication -> authentication.getCredentials() instanceof AbstractOAuth2Token) + .map(Authentication::getCredentials) + .cast(AbstractOAuth2Token.class); + } + + private Mono currentAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + } + + private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { + return ClientRequest.from(request) + .headers(headers -> headers.setBearerAuth(token.getTokenValue())) + .build(); + } +} diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java new file mode 100644 index 0000000000..a4da50ea00 --- /dev/null +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.web; + +import java.util.ArrayList; +import java.util.List; + +import reactor.core.publisher.Mono; + +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFunction; + +import static org.mockito.Mockito.mock; + +/** + * @author Rob Winch + * @since 5.1 + */ +public class MockExchangeFunction implements ExchangeFunction { + private List requests = new ArrayList<>(); + + private ClientResponse response = mock(ClientResponse.class); + + public ClientRequest getRequest() { + return this.requests.get(this.requests.size() - 1); + } + + public List getRequests() { + return this.requests; + } + + public ClientResponse getResponse() { + return this.response; + } + + @Override + public Mono exchange(ClientRequest request) { + return Mono.defer(() -> { + this.requests.add(request); + return Mono.just(this.response); + }); + } +} diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunctionTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunctionTests.java new file mode 100644 index 0000000000..f1217a63a5 --- /dev/null +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/ServletBearerExchangeFilterFunctionTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.web; + +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +import org.springframework.http.HttpHeaders; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; +import org.springframework.web.reactive.function.client.ClientRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.HttpMethod.GET; +import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication; + +/** + * Tests for {@link ServletBearerExchangeFilterFunction} + * + * @author Josh Cummings + */ +@RunWith(MockitoJUnitRunner.class) +public class ServletBearerExchangeFilterFunctionTests { + private ServletBearerExchangeFilterFunction function = new ServletBearerExchangeFilterFunction(); + + private MockExchangeFunction exchange = new MockExchangeFunction(); + + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "token-0", + Instant.now(), + Instant.now().plus(Duration.ofDays(1))); + private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken(accessToken) { + @Override + public Map getTokenAttributes() { + return Collections.emptyMap(); + } + }; + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + } + + @Test + public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .build(); + + this.function.filter(request, this.exchange).block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + } + + @Test + public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception { + this.function.afterPropertiesSet(); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .build(); + + this.function.filter(request, this.exchange).block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); + } + + @Test + public void filterWhenAuthenticationAttributeThenAuthorizationHeader() { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); + } + + @Test + public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .header(HttpHeaders.AUTHORIZATION, "Existing") + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + HttpHeaders headers = this.exchange.getRequest().headers(); + assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); + } +} diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunctionTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunctionTests.java new file mode 100644 index 0000000000..0a5ac9b0f5 --- /dev/null +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerExchangeFilterFunctionTests.java @@ -0,0 +1,107 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.web.server; + +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; +import org.springframework.security.oauth2.server.resource.web.MockExchangeFunction; +import org.springframework.web.reactive.function.client.ClientRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.HttpMethod.GET; +import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication; + +/** + * Tests for {@link ServerBearerExchangeFilterFunction} + * + * @author Josh Cummings + */ +public class ServerBearerExchangeFilterFunctionTests { + private ServerBearerExchangeFilterFunction function = new ServerBearerExchangeFilterFunction(); + + private MockExchangeFunction exchange = new MockExchangeFunction(); + + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "token-0", + Instant.now(), + Instant.now().plus(Duration.ofDays(1))); + private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken(accessToken) { + @Override + public Map getTokenAttributes() { + return Collections.emptyMap(); + } + }; + + @Test + public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .build(); + + this.function.filter(request, this.exchange).block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + } + + @Test + public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .build(); + + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); + } + + @Test + public void filterWhenAuthenticationAttributeThenAuthorizationHeader() { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); + } + + @Test + public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .header(HttpHeaders.AUTHORIZATION, "Existing") + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + HttpHeaders headers = this.exchange.getRequest().headers(); + assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); + } +}