URL encode client credentials

Closes gh-9610
This commit is contained in:
Steve Riesenberg 2021-05-21 15:03:10 -05:00 committed by Steve Riesenberg
parent 68f91edbb8
commit ac9b137cad
4 changed files with 106 additions and 5 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,9 @@
package org.springframework.security.oauth2.client.endpoint; package org.springframework.security.oauth2.client.endpoint;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections; import java.util.Collections;
import java.util.Set; import java.util.Set;
@ -97,7 +100,19 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientRegistration.getClientAuthenticationMethod()) if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientRegistration.getClientAuthenticationMethod())
|| ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { || ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); String clientId = encodeClientCredential(clientRegistration.getClientId());
String clientSecret = encodeClientCredential(clientRegistration.getClientSecret());
headers.setBasicAuth(clientId, clientSecret);
}
}
private static String encodeClientCredential(String clientCredential) {
try {
return URLEncoder.encode(clientCredential, StandardCharsets.UTF_8.toString());
}
catch (UnsupportedEncodingException ex) {
// Will not happen since UTF-8 is a standard charset
throw new IllegalArgumentException(ex);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,9 @@
package org.springframework.security.oauth2.client.endpoint; package org.springframework.security.oauth2.client.endpoint;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections; import java.util.Collections;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
@ -48,11 +51,23 @@ final class OAuth2AuthorizationGrantRequestEntityUtils {
headers.addAll(DEFAULT_TOKEN_REQUEST_HEADERS); headers.addAll(DEFAULT_TOKEN_REQUEST_HEADERS);
if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientRegistration.getClientAuthenticationMethod()) if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientRegistration.getClientAuthenticationMethod())
|| ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { || ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); String clientId = encodeClientCredential(clientRegistration.getClientId());
String clientSecret = encodeClientCredential(clientRegistration.getClientSecret());
headers.setBasicAuth(clientId, clientSecret);
} }
return headers; return headers;
} }
private static String encodeClientCredential(String clientCredential) {
try {
return URLEncoder.encode(clientCredential, StandardCharsets.UTF_8.toString());
}
catch (UnsupportedEncodingException ex) {
// Will not happen since UTF-8 is a standard charset
throw new IllegalArgumentException(ex);
}
}
private static HttpHeaders getDefaultTokenRequestHeaders() { private static HttpHeaders getDefaultTokenRequestHeaders() {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));

View File

@ -16,6 +16,11 @@
package org.springframework.security.oauth2.client.endpoint; package org.springframework.security.oauth2.client.endpoint;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.InOrder; import org.mockito.InOrder;
@ -128,4 +133,37 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverterTests {
assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).contains(clientRegistration.getScopes()); assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).contains(clientRegistration.getScopes());
} }
// gh-9610
@SuppressWarnings("unchecked")
@Test
public void convertWhenSpecialCharactersThenConvertsWithEncodedClientCredentials()
throws UnsupportedEncodingException {
String clientCredentialWithAnsiKeyboardSpecialCharacters = "~!@#$%^&*()_+{}|:\"<>?`-=[]\\;',./ ";
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientId(clientCredentialWithAnsiKeyboardSpecialCharacters)
.clientSecret(clientCredentialWithAnsiKeyboardSpecialCharacters)
.build();
// @formatter:on
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
RequestEntity<?> requestEntity = this.converter.convert(clientCredentialsGrantRequest);
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST);
assertThat(requestEntity.getUrl().toASCIIString())
.isEqualTo(clientRegistration.getProviderDetails().getTokenUri());
HttpHeaders headers = requestEntity.getHeaders();
assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
assertThat(headers.getContentType())
.isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"));
String urlEncodedClientCredential = URLEncoder.encode(clientCredentialWithAnsiKeyboardSpecialCharacters,
StandardCharsets.UTF_8.toString());
String clientCredentials = Base64.getEncoder().encodeToString(
(urlEncodedClientCredential + ":" + urlEncodedClientCredential).getBytes(StandardCharsets.UTF_8));
assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic " + clientCredentials);
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE))
.isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).contains(clientRegistration.getScopes());
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,10 @@
package org.springframework.security.oauth2.client.endpoint; package org.springframework.security.oauth2.client.endpoint;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest; import okhttp3.mockwebserver.RecordedRequest;
@ -89,6 +93,35 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser"); assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser");
} }
// gh-9610
@Test
public void getTokenResponseWhenSpecialCharactersThenSuccessWithEncodedClientCredentials() throws Exception {
// @formatter:off
enqueueJson("{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
// @formatter:on
String clientCredentialWithAnsiKeyboardSpecialCharacters = "~!@#$%^&*()_+{}|:\"<>?`-=[]\\;',./ ";
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.clientId(clientCredentialWithAnsiKeyboardSpecialCharacters)
.clientSecret(clientCredentialWithAnsiKeyboardSpecialCharacters).build());
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
RecordedRequest actualRequest = this.server.takeRequest();
String body = actualRequest.getBody().readUtf8();
assertThat(response.getAccessToken()).isNotNull();
String urlEncodedClientCredentialecret = URLEncoder.encode(clientCredentialWithAnsiKeyboardSpecialCharacters,
StandardCharsets.UTF_8.toString());
String clientCredentials = Base64.getEncoder()
.encodeToString((urlEncodedClientCredentialecret + ":" + urlEncodedClientCredentialecret)
.getBytes(StandardCharsets.UTF_8));
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic " + clientCredentials);
assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser");
}
@Test @Test
public void getTokenResponseWhenPostThenSuccess() throws Exception { public void getTokenResponseWhenPostThenSuccess() throws Exception {
ClientRegistration registration = this.clientRegistration ClientRegistration registration = this.clientRegistration