Add OAuth2AuthorizedClientExchangeFilterFunction

Fixes: gh-5386
This commit is contained in:
Rob Winch 2018-05-25 09:25:26 -05:00
parent 2658577396
commit c68cf991ae
3 changed files with 232 additions and 0 deletions

View File

@ -0,0 +1,84 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.springframework.http.HttpHeaders;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import reactor.core.publisher.Mono;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
/**
* Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
* token as a Bearer Token.
*
* @author Rob Winch
* @since 5.1
*/
public final class OAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
/**
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
*/
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
/**
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
* providing the Bearer Token. Example usage:
*
* <pre>
* Mono<String> response = this.webClient
* .get()
* .uri(uri)
* .attributes(oauth2AuthorizedClient(authorizedClient))
* // ...
* .retrieve()
* .bodyToMono(String.class);
* </pre>
*
* @param authorizedClient the {@link OAuth2AuthorizedClient} to use.
* @return the {@link Consumer} to populate the
*/
public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) {
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
}
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
.map(OAuth2AuthorizedClient.class::cast);
return attribute
.map(authorizedClient -> bearer(request, authorizedClient))
.map(next::exchange)
.orElseGet(() -> next.exchange(request));
}
private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
return ClientRequest.from(request)
.headers(bearerToken(authorizedClient.getAccessToken().getTokenValue()))
.build();
}
private Consumer<HttpHeaders> bearerToken(String token) {
return headers -> headers.set(HttpHeaders.AUTHORIZATION, "Bearer " + token);
}
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.reactive.function.client;
import static org.mockito.Mockito.mock;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import reactor.core.publisher.Mono;
/**
* @author Rob Winch
* @since 5.1
*/
public class MockExchangeFunction implements ExchangeFunction {
private ClientRequest request;
private ClientResponse response = mock(ClientResponse.class);
public ClientRequest getRequest() {
return this.request;
}
@Override
public Mono<ClientResponse> exchange(ClientRequest request) {
return Mono.defer(() -> {
this.request = request;
return Mono.just(this.response);
});
}
}

View File

@ -0,0 +1,101 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.web.reactive.function.client.ClientRequest;
import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import static org.assertj.core.api.Assertions.*;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
/**
* @author Rob Winch
* @since 5.1
*/
public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
private OAuth2AuthorizedClientExchangeFilterFunction function = new OAuth2AuthorizedClientExchangeFilterFunction();
private MockExchangeFunction exchange = new MockExchangeFunction();
private ClientRegistration github = ClientRegistration.withRegistrationId("github")
.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope("read:user")
.authorizationUri("https://github.com/login/oauth/authorize")
.tokenUri("https://github.com/login/oauth/access_token")
.userInfoUri("https://api.github.com/user")
.userNameAttributeName("id")
.clientName("GitHub")
.clientId("clientId")
.clientSecret("clientSecret")
.build();
private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"token",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build();
this.function.filter(request, this.exchange).block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
}
@Test
public void filterWhenAuthorizedClientThenAuthorizationHeader() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
"principalName", this.accessToken);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();
this.function.filter(request, this.exchange).block();
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
}
@Test
public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
"principalName", this.accessToken);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.header(HttpHeaders.AUTHORIZATION, "Existing")
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();
this.function.filter(request, this.exchange).block();
HttpHeaders headers = this.exchange.getRequest().headers();
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
}
}