Enable customizing headers in token requests
Adds the possibility to customize the headers of the access token request in AbstractWebClientReactiveOAuth2AccessTokenResponseClient, similarly to what is done in the AbstractOAuth2AuthorizationGrantRequestEntityConverter. Closes gh-10130
This commit is contained in:
parent
1806cebd64
commit
044157061f
|
@ -24,6 +24,7 @@ import java.util.Set;
|
||||||
|
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||||
|
@ -65,6 +66,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
|
||||||
|
|
||||||
private WebClient webClient = WebClient.builder().build();
|
private WebClient webClient = WebClient.builder().build();
|
||||||
|
|
||||||
|
private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;
|
||||||
|
|
||||||
AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
|
AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +77,12 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
|
||||||
// @formatter:off
|
// @formatter:off
|
||||||
return Mono.defer(() -> this.webClient.post()
|
return Mono.defer(() -> this.webClient.post()
|
||||||
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
|
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
|
||||||
.headers((headers) -> populateTokenRequestHeaders(grantRequest, headers))
|
.headers((headers) -> {
|
||||||
|
HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest);
|
||||||
|
if (headersToAdd != null) {
|
||||||
|
headers.addAll(headersToAdd);
|
||||||
|
}
|
||||||
|
})
|
||||||
.body(createTokenRequestBody(grantRequest))
|
.body(createTokenRequestBody(grantRequest))
|
||||||
.exchange()
|
.exchange()
|
||||||
.flatMap((response) -> readTokenResponse(grantRequest, response))
|
.flatMap((response) -> readTokenResponse(grantRequest, response))
|
||||||
|
@ -92,9 +100,10 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
|
||||||
/**
|
/**
|
||||||
* Populates the headers for the token request.
|
* Populates the headers for the token request.
|
||||||
* @param grantRequest the grant request
|
* @param grantRequest the grant request
|
||||||
* @param headers the headers to populate
|
* @return the headers populated for the token request
|
||||||
*/
|
*/
|
||||||
private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) {
|
private HttpHeaders populateTokenRequestHeaders(T grantRequest) {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
ClientRegistration clientRegistration = clientRegistration(grantRequest);
|
ClientRegistration clientRegistration = clientRegistration(grantRequest);
|
||||||
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
|
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
|
||||||
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
|
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
|
||||||
|
@ -104,6 +113,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
|
||||||
String clientSecret = encodeClientCredential(clientRegistration.getClientSecret());
|
String clientSecret = encodeClientCredential(clientRegistration.getClientSecret());
|
||||||
headers.setBasicAuth(clientId, clientSecret);
|
headers.setBasicAuth(clientId, clientSecret);
|
||||||
}
|
}
|
||||||
|
return headers;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static String encodeClientCredential(String clientCredential) {
|
private static String encodeClientCredential(String clientCredential) {
|
||||||
|
@ -230,4 +240,55 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
|
||||||
this.webClient = webClient;
|
this.webClient = webClient;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link Converter} used for converting the
|
||||||
|
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
|
||||||
|
* used in the OAuth 2.0 Access Token Request headers.
|
||||||
|
* @return the {@link Converter} used for converting the
|
||||||
|
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
|
||||||
|
*/
|
||||||
|
final Converter<T, HttpHeaders> getHeadersConverter() {
|
||||||
|
return this.headersConverter;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link Converter} used for converting the
|
||||||
|
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
|
||||||
|
* used in the OAuth 2.0 Access Token Request headers.
|
||||||
|
* @param headersConverter the {@link Converter} used for converting the
|
||||||
|
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
|
||||||
|
* @since 5.6
|
||||||
|
*/
|
||||||
|
public final void setHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
|
||||||
|
Assert.notNull(headersConverter, "headersConverter cannot be null");
|
||||||
|
this.headersConverter = headersConverter;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add (compose) the provided {@code headersConverter} to the current
|
||||||
|
* {@link Converter} used for converting the
|
||||||
|
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
|
||||||
|
* used in the OAuth 2.0 Access Token Request headers.
|
||||||
|
* @param headersConverter the {@link Converter} to add (compose) to the current
|
||||||
|
* {@link Converter} used for converting the
|
||||||
|
* {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link HttpHeaders}
|
||||||
|
* @since 5.6
|
||||||
|
*/
|
||||||
|
public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
|
||||||
|
Assert.notNull(headersConverter, "headersConverter cannot be null");
|
||||||
|
Converter<T, HttpHeaders> currentHeadersConverter = this.headersConverter;
|
||||||
|
this.headersConverter = (authorizationGrantRequest) -> {
|
||||||
|
// Append headers using a Composite Converter
|
||||||
|
HttpHeaders headers = currentHeadersConverter.convert(authorizationGrantRequest);
|
||||||
|
if (headers == null) {
|
||||||
|
headers = new HttpHeaders();
|
||||||
|
}
|
||||||
|
HttpHeaders headersToAdd = headersConverter.convert(authorizationGrantRequest);
|
||||||
|
if (headersToAdd != null) {
|
||||||
|
headers.addAll(headersToAdd);
|
||||||
|
}
|
||||||
|
return headers;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2020 the original author or authors.
|
* Copyright 2002-2021 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,15 +17,18 @@
|
||||||
package org.springframework.security.oauth2.client.endpoint;
|
package org.springframework.security.oauth2.client.endpoint;
|
||||||
|
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import okhttp3.mockwebserver.MockResponse;
|
import okhttp3.mockwebserver.MockResponse;
|
||||||
import okhttp3.mockwebserver.MockWebServer;
|
import okhttp3.mockwebserver.MockWebServer;
|
||||||
|
import okhttp3.mockwebserver.RecordedRequest;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpStatus;
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
|
@ -340,4 +343,65 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
|
||||||
return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);
|
return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
|
||||||
|
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
|
||||||
|
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(addedHeadersConverter.convert(request)).willReturn(headers);
|
||||||
|
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
|
||||||
|
// @formatter:off
|
||||||
|
String accessTokenSuccessResponse = "{\n"
|
||||||
|
+ " \"access_token\": \"access-token-1234\",\n"
|
||||||
|
+ " \"token_type\": \"bearer\",\n"
|
||||||
|
+ " \"expires_in\": \"3600\",\n"
|
||||||
|
+ " \"scope\": \"openid profile\"\n"
|
||||||
|
+ "}\n";
|
||||||
|
// @formatter:on
|
||||||
|
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
|
||||||
|
this.tokenResponseClient.getTokenResponse(request).block();
|
||||||
|
verify(addedHeadersConverter).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
|
||||||
|
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
|
||||||
|
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
|
||||||
|
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
|
||||||
|
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(headersConverter.convert(request)).willReturn(headers);
|
||||||
|
this.tokenResponseClient.setHeadersConverter(headersConverter);
|
||||||
|
// @formatter:off
|
||||||
|
String accessTokenSuccessResponse = "{\n"
|
||||||
|
+ " \"access_token\": \"access-token-1234\",\n"
|
||||||
|
+ " \"token_type\": \"bearer\",\n"
|
||||||
|
+ " \"expires_in\": \"3600\",\n"
|
||||||
|
+ " \"scope\": \"openid profile\"\n"
|
||||||
|
+ "}\n";
|
||||||
|
// @formatter:on
|
||||||
|
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
|
||||||
|
this.tokenResponseClient.getTokenResponse(request).block();
|
||||||
|
verify(headersConverter).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.endpoint;
|
||||||
import java.net.URLEncoder;
|
import java.net.URLEncoder;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
|
import java.util.Collections;
|
||||||
|
|
||||||
import okhttp3.mockwebserver.MockResponse;
|
import okhttp3.mockwebserver.MockResponse;
|
||||||
import okhttp3.mockwebserver.MockWebServer;
|
import okhttp3.mockwebserver.MockWebServer;
|
||||||
|
@ -27,6 +28,7 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||||
|
@ -212,4 +214,64 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
|
||||||
this.server.enqueue(response);
|
this.server.enqueue(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.client.addHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
|
||||||
|
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
|
||||||
|
this.clientRegistration.build());
|
||||||
|
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(addedHeadersConverter.convert(request)).willReturn(headers);
|
||||||
|
this.client.addHeadersConverter(addedHeadersConverter);
|
||||||
|
// @formatter:off
|
||||||
|
enqueueJson("{\n"
|
||||||
|
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
|
||||||
|
+ " \"token_type\":\"bearer\",\n"
|
||||||
|
+ " \"expires_in\":3600,\n"
|
||||||
|
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
|
||||||
|
+ "}");
|
||||||
|
// @formatter:on
|
||||||
|
this.client.getTokenResponse(request).block();
|
||||||
|
verify(addedHeadersConverter).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
|
||||||
|
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
|
||||||
|
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
|
||||||
|
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
|
||||||
|
this.clientRegistration.build());
|
||||||
|
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(headersConverter.convert(request)).willReturn(headers);
|
||||||
|
this.client.setHeadersConverter(headersConverter);
|
||||||
|
// @formatter:off
|
||||||
|
enqueueJson("{\n"
|
||||||
|
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
|
||||||
|
+ " \"token_type\":\"bearer\",\n"
|
||||||
|
+ " \"expires_in\":3600,\n"
|
||||||
|
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
|
||||||
|
+ "}");
|
||||||
|
// @formatter:on
|
||||||
|
this.client.getTokenResponse(request).block();
|
||||||
|
verify(headersConverter).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2020 the original author or authors.
|
* Copyright 2002-2021 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,6 +17,7 @@
|
||||||
package org.springframework.security.oauth2.client.endpoint;
|
package org.springframework.security.oauth2.client.endpoint;
|
||||||
|
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
|
import java.util.Collections;
|
||||||
|
|
||||||
import okhttp3.mockwebserver.MockResponse;
|
import okhttp3.mockwebserver.MockResponse;
|
||||||
import okhttp3.mockwebserver.MockWebServer;
|
import okhttp3.mockwebserver.MockWebServer;
|
||||||
|
@ -25,6 +26,7 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
|
@ -38,6 +40,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||||
|
import static org.mockito.BDDMockito.given;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for {@link WebClientReactivePasswordTokenResponseClient}.
|
* Tests for {@link WebClientReactivePasswordTokenResponseClient}.
|
||||||
|
@ -213,4 +218,66 @@ public class WebClientReactivePasswordTokenResponseClientTests {
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
|
||||||
|
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
|
||||||
|
this.username, this.password);
|
||||||
|
Converter<OAuth2PasswordGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(addedHeadersConverter.convert(request)).willReturn(headers);
|
||||||
|
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
|
||||||
|
// @formatter:off
|
||||||
|
String accessTokenSuccessResponse = "{\n"
|
||||||
|
+ " \"access_token\": \"access-token-1234\",\n"
|
||||||
|
+ " \"token_type\": \"bearer\",\n"
|
||||||
|
+ " \"expires_in\": \"3600\",\n"
|
||||||
|
+ " \"scope\": \"read\"\n"
|
||||||
|
+ "}\n";
|
||||||
|
// @formatter:on
|
||||||
|
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
|
||||||
|
this.tokenResponseClient.getTokenResponse(request).block();
|
||||||
|
verify(addedHeadersConverter).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
|
||||||
|
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
|
||||||
|
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
|
||||||
|
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
|
||||||
|
this.username, this.password);
|
||||||
|
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(headersConverter.convert(request)).willReturn(headers);
|
||||||
|
this.tokenResponseClient.setHeadersConverter(headersConverter);
|
||||||
|
// @formatter:off
|
||||||
|
String accessTokenSuccessResponse = "{\n"
|
||||||
|
+ " \"access_token\": \"access-token-1234\",\n"
|
||||||
|
+ " \"token_type\": \"bearer\",\n"
|
||||||
|
+ " \"expires_in\": \"3600\",\n"
|
||||||
|
+ " \"scope\": \"read\"\n"
|
||||||
|
+ "}\n";
|
||||||
|
// @formatter:on
|
||||||
|
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
|
||||||
|
this.tokenResponseClient.getTokenResponse(request).block();
|
||||||
|
verify(headersConverter).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2020 the original author or authors.
|
* Copyright 2002-2021 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,6 +26,7 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import org.springframework.core.convert.converter.Converter;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
|
@ -42,6 +43,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||||
|
import static org.mockito.BDDMockito.given;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for {@link WebClientReactiveRefreshTokenTokenResponseClient}.
|
* Tests for {@link WebClientReactiveRefreshTokenTokenResponseClient}.
|
||||||
|
@ -217,4 +221,66 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
|
||||||
|
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
|
||||||
|
.withMessage("headersConverter cannot be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
|
||||||
|
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
|
||||||
|
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
|
||||||
|
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(addedHeadersConverter.convert(request)).willReturn(headers);
|
||||||
|
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
|
||||||
|
// @formatter:off
|
||||||
|
String accessTokenSuccessResponse = "{\n"
|
||||||
|
+ " \"access_token\": \"access-token-1234\",\n"
|
||||||
|
+ " \"token_type\": \"bearer\",\n"
|
||||||
|
+ " \"expires_in\": \"3600\",\n"
|
||||||
|
+ " \"scope\": \"read\"\n"
|
||||||
|
+ "}\n";
|
||||||
|
// @formatter:on
|
||||||
|
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
|
||||||
|
this.tokenResponseClient.getTokenResponse(request).block();
|
||||||
|
verify(addedHeadersConverter).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
|
||||||
|
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
|
||||||
|
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
|
||||||
|
OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(
|
||||||
|
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
|
||||||
|
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter1 = mock(Converter.class);
|
||||||
|
final HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
|
||||||
|
given(headersConverter1.convert(request)).willReturn(headers);
|
||||||
|
this.tokenResponseClient.setHeadersConverter(headersConverter1);
|
||||||
|
// @formatter:off
|
||||||
|
String accessTokenSuccessResponse = "{\n"
|
||||||
|
+ " \"access_token\": \"access-token-1234\",\n"
|
||||||
|
+ " \"token_type\": \"bearer\",\n"
|
||||||
|
+ " \"expires_in\": \"3600\",\n"
|
||||||
|
+ " \"scope\": \"read\"\n"
|
||||||
|
+ "}\n";
|
||||||
|
// @formatter:on
|
||||||
|
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
|
||||||
|
this.tokenResponseClient.getTokenResponse(request).block();
|
||||||
|
verify(headersConverter1).convert(request);
|
||||||
|
RecordedRequest actualRequest = this.server.takeRequest();
|
||||||
|
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue