WebClientReactiveClientCredentialsTokenResponseClient

Fixes: gh-5607
This commit is contained in:
Rob Winch 2018-09-04 15:14:27 -05:00
parent 89f2874bff
commit 28537fa3b6
2 changed files with 233 additions and 0 deletions

View File

@ -0,0 +1,107 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import java.util.Set;
import java.util.function.Consumer;
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
/**
* An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges"
* an authorization code credential for an access token credential
* at the Authorization Server's Token Endpoint.
*
* @author Rob Winch
* @since 5.1
* @see OAuth2AccessTokenResponseClient
* @see OAuth2AuthorizationCodeGrantRequest
* @see OAuth2AccessTokenResponse
* @see <a target="_blank" href="https://connect2id.com/products/nimbus-oauth-openid-connect-sdk">Nimbus OAuth 2.0 SDK</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request (Authorization Code Grant)</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response (Authorization Code Grant)</a>
*/
public class WebClientReactiveClientCredentialsTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
private WebClient webClient = WebClient.builder()
.build();
@Override
public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest)
throws OAuth2AuthenticationException {
return Mono.defer(() -> {
ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
BodyInserters.FormInserter<String> body = body(authorizationGrantRequest);
return this.webClient.post()
.uri(tokenUri)
.accept(MediaType.APPLICATION_JSON)
.headers(headers(clientRegistration))
.body(body)
.exchange()
.flatMap(response -> response.body(oauth2AccessTokenResponse()))
.map(response -> {
if (response.getAccessToken().getScopes().isEmpty()) {
response = OAuth2AccessTokenResponse.withResponse(response)
.scopes(authorizationGrantRequest.getClientRegistration().getScopes())
.build();
}
return response;
});
});
}
private Consumer<HttpHeaders> headers(ClientRegistration clientRegistration) {
return headers -> {
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
}
};
}
private static BodyInserters.FormInserter<String> body(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) {
ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
BodyInserters.FormInserter<String> body = BodyInserters
.fromFormData(OAuth2ParameterNames.GRANT_TYPE, authorizationGrantRequest.getGrantType().getValue());
Set<String> scopes = clientRegistration.getScopes();
if (!CollectionUtils.isEmpty(scopes)) {
String scope = StringUtils.collectionToDelimitedString(scopes, " ");
body.with(OAuth2ParameterNames.SCOPE, scope);
}
if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return body;
}
}

View File

@ -0,0 +1,126 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import static org.assertj.core.api.Assertions.*;
/**
* @author Rob Winch
*/
public class WebClientReactiveClientCredentialsTokenResponseClientTests {
private MockWebServer server;
private WebClientReactiveClientCredentialsTokenResponseClient client = new WebClientReactiveClientCredentialsTokenResponseClient();
private ClientRegistration.Builder clientRegistration;
@Before
public void setup() throws Exception {
this.server = new MockWebServer();
this.server.start();
this.clientRegistration = TestClientRegistrations
.clientCredentials()
.tokenUri(this.server.url("/oauth2/token").uri().toASCIIString());
}
@After
public void cleanup() throws Exception {
this.server.shutdown();
}
@Test
public void getTokenResponseWhenHeaderThenSuccess() throws Exception {
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(this.clientRegistration
.build());
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
RecordedRequest actualRequest = this.server.takeRequest();
String body = actualRequest.getUtf8Body();
assertThat(response.getAccessToken()).isNotNull();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser");
}
@Test
public void getTokenResponseWhenPostThenSuccess() throws Exception {
ClientRegistration registration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.POST)
.build();
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
String body = this.server.takeRequest().getUtf8Body();
assertThat(response.getAccessToken()).isNotNull();
assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser&client_id=client-id&client_secret=client-secret");
}
@Test
public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() {
ClientRegistration registration = this.clientRegistration.build();
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes());
}
private void enqueueJson(String body) {
MockResponse response = new MockResponse()
.setBody(body)
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
this.server.enqueue(response);
}
}