Add OAuth2AuthorizedClientExchangeFilterFunction
Fixes: gh-5386
This commit is contained in:
parent
2658577396
commit
c68cf991ae
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFunction;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
|
||||
* token as a Bearer Token.
|
||||
*
|
||||
* @author Rob Winch
|
||||
* @since 5.1
|
||||
*/
|
||||
public final class OAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
|
||||
/**
|
||||
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
|
||||
*/
|
||||
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
|
||||
|
||||
/**
|
||||
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
|
||||
* providing the Bearer Token. Example usage:
|
||||
*
|
||||
* <pre>
|
||||
* Mono<String> response = this.webClient
|
||||
* .get()
|
||||
* .uri(uri)
|
||||
* .attributes(oauth2AuthorizedClient(authorizedClient))
|
||||
* // ...
|
||||
* .retrieve()
|
||||
* .bodyToMono(String.class);
|
||||
* </pre>
|
||||
*
|
||||
* @param authorizedClient the {@link OAuth2AuthorizedClient} to use.
|
||||
* @return the {@link Consumer} to populate the
|
||||
*/
|
||||
public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) {
|
||||
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
||||
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
||||
.map(OAuth2AuthorizedClient.class::cast);
|
||||
return attribute
|
||||
.map(authorizedClient -> bearer(request, authorizedClient))
|
||||
.map(next::exchange)
|
||||
.orElseGet(() -> next.exchange(request));
|
||||
}
|
||||
|
||||
private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
|
||||
return ClientRequest.from(request)
|
||||
.headers(bearerToken(authorizedClient.getAccessToken().getTokenValue()))
|
||||
.build();
|
||||
}
|
||||
|
||||
private Consumer<HttpHeaders> bearerToken(String token) {
|
||||
return headers -> headers.set(HttpHeaders.AUTHORIZATION, "Bearer " + token);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.ExchangeFunction;
|
||||
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
/**
|
||||
* @author Rob Winch
|
||||
* @since 5.1
|
||||
*/
|
||||
public class MockExchangeFunction implements ExchangeFunction {
|
||||
private ClientRequest request;
|
||||
|
||||
private ClientResponse response = mock(ClientResponse.class);
|
||||
|
||||
public ClientRequest getRequest() {
|
||||
return this.request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<ClientResponse> exchange(ClientRequest request) {
|
||||
return Mono.defer(() -> {
|
||||
this.request = request;
|
||||
return Mono.just(this.response);
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static org.springframework.http.HttpMethod.GET;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
|
||||
|
||||
/**
|
||||
* @author Rob Winch
|
||||
* @since 5.1
|
||||
*/
|
||||
public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
private OAuth2AuthorizedClientExchangeFilterFunction function = new OAuth2AuthorizedClientExchangeFilterFunction();
|
||||
|
||||
private MockExchangeFunction exchange = new MockExchangeFunction();
|
||||
|
||||
private ClientRegistration github = ClientRegistration.withRegistrationId("github")
|
||||
.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
|
||||
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.scope("read:user")
|
||||
.authorizationUri("https://github.com/login/oauth/authorize")
|
||||
.tokenUri("https://github.com/login/oauth/access_token")
|
||||
.userInfoUri("https://api.github.com/user")
|
||||
.userNameAttributeName("id")
|
||||
.clientName("GitHub")
|
||||
.clientId("clientId")
|
||||
.clientSecret("clientSecret")
|
||||
.build();
|
||||
|
||||
private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
|
||||
"token",
|
||||
Instant.now(),
|
||||
Instant.now().plus(Duration.ofDays(1)));
|
||||
|
||||
@Test
|
||||
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||
.build();
|
||||
|
||||
this.function.filter(request, this.exchange).block();
|
||||
|
||||
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenAuthorizedClientThenAuthorizationHeader() {
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
|
||||
"principalName", this.accessToken);
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||
.attributes(oauth2AuthorizedClient(authorizedClient))
|
||||
.build();
|
||||
|
||||
this.function.filter(request, this.exchange).block();
|
||||
|
||||
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
|
||||
"principalName", this.accessToken);
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||
.header(HttpHeaders.AUTHORIZATION, "Existing")
|
||||
.attributes(oauth2AuthorizedClient(authorizedClient))
|
||||
.build();
|
||||
|
||||
this.function.filter(request, this.exchange).block();
|
||||
|
||||
HttpHeaders headers = this.exchange.getRequest().headers();
|
||||
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue