From 0116c65c0e46674ea5a4ac4d76a7d91a382ee7a3 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Mon, 2 Jul 2018 11:29:39 -0500 Subject: [PATCH] OAuth2AuthorizedClientExchangeFilterFunction Refresh Support --- ...uthorizedClientExchangeFilterFunction.java | 190 ++++++++++++++- ...izedClientExchangeFilterFunctionTests.java | 216 +++++++++++++++++- 2 files changed, 393 insertions(+), 13 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunction.java index 8df207778a..45212eb593 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunction.java @@ -16,20 +16,40 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; -import static org.springframework.security.web.http.SecurityHeaders.bearerToken; - -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; - +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.util.Assert; +import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; - import reactor.core.publisher.Mono; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; +import static org.springframework.security.web.http.SecurityHeaders.bearerToken; + /** * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the * token as a Bearer Token. @@ -43,12 +63,27 @@ public final class OAuth2AuthorizedClientExchangeFilterFunction implements Excha */ private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName(); + private Clock clock = Clock.systemUTC(); + + private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); + + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + + public OAuth2AuthorizedClientExchangeFilterFunction() {} + + public OAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientService authorizedClientService) { + this.authorizedClientService = authorizedClientService; + } + /** * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for * providing the Bearer Token. Example usage: * *
-	 * Mono response = this.webClient
+	 * WebClient webClient = WebClient.builder()
+	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService))
+	 *    .build();
+	 * Mono response = webClient
 	 *    .get()
 	 *    .uri(uri)
 	 *    .attributes(oauth2AuthorizedClient(authorizedClient))
@@ -57,6 +92,20 @@ public final class OAuth2AuthorizedClientExchangeFilterFunction implements Excha
 	 *    .bodyToMono(String.class);
 	 * 
* + * An attempt to automatically refresh the token will be made if all of the following + * are true: + * + * + * * @param authorizedClient the {@link OAuth2AuthorizedClient} to use. * @return the {@link Consumer} to populate the */ @@ -64,14 +113,79 @@ public final class OAuth2AuthorizedClientExchangeFilterFunction implements Excha return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient); } + /** + * An access token will be considered expired by comparing its expiration to now + + * this skewed Duration. The default is 1 minute. + * @param accessTokenExpiresSkew the Duration to use. + */ + public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { + Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); + this.accessTokenExpiresSkew = accessTokenExpiresSkew; + } + @Override public Mono filter(ClientRequest request, ExchangeFunction next) { Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) .map(OAuth2AuthorizedClient.class::cast); - return attribute + return Mono.justOrEmpty(attribute) + .flatMap(authorizedClient -> authorizedClient(next, authorizedClient)) .map(authorizedClient -> bearer(request, authorizedClient)) - .map(next::exchange) - .orElseGet(() -> next.exchange(request)); + .flatMap(next::exchange) + .switchIfEmpty(next.exchange(request)); + } + + private Mono authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { + if (shouldRefresh(authorizedClient)) { + return refreshAuthorizedClient(next, authorizedClient); + } + return Mono.just(authorizedClient); + } + + private Mono refreshAuthorizedClient(ExchangeFunction next, + OAuth2AuthorizedClient authorizedClient) { + ClientRegistration clientRegistration = authorizedClient + .getClientRegistration(); + String tokenUri = clientRegistration + .getProviderDetails().getTokenUri(); + ClientRequest request = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) + .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .headers(httpBasic(clientRegistration.getClientId(), clientRegistration.getClientSecret())) + .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue())) + .build(); + return next.exchange(request) + .flatMap(response -> response.body(oauth2AccessTokenResponse())) + .map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken())) + .flatMap(result -> ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())) + .flatMap(principal -> this.authorizedClientService.saveAuthorizedClient(result, principal)) + .thenReturn(result)); + } + + private static Consumer httpBasic(String username, String password) { + return httpHeaders -> { + String credentialsString = username + ":" + password; + byte[] credentialBytes = credentialsString.getBytes(StandardCharsets.ISO_8859_1); + byte[] encodedBytes = Base64.getEncoder().encode(credentialBytes); + String encodedCredentials = new String(encodedBytes, StandardCharsets.ISO_8859_1); + httpHeaders.set(HttpHeaders.AUTHORIZATION, "Basic " + encodedCredentials); + }; + } + + private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { + if (this.authorizedClientService == null) { + return false; + } + OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); + if (refreshToken == null) { + return false; + } + Instant now = this.clock.instant(); + Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); + if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { + return true; + } + return false; } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { @@ -79,4 +193,58 @@ public final class OAuth2AuthorizedClientExchangeFilterFunction implements Excha .headers(bearerToken(authorizedClient.getAccessToken().getTokenValue())) .build(); } + + private static BodyInserters.FormInserter refreshTokenBody(String refreshToken) { + return BodyInserters + .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) + .with("refresh_token", refreshToken); + } + + private static class PrincipalNameAuthentication implements Authentication { + private final String username; + + private PrincipalNameAuthentication(String username) { + this.username = username; + } + + @Override + public Collection getAuthorities() { + throw unsupported(); + } + + @Override + public Object getCredentials() { + throw unsupported(); + } + + @Override + public Object getDetails() { + throw unsupported(); + } + + @Override + public Object getPrincipal() { + throw unsupported(); + } + + @Override + public boolean isAuthenticated() { + throw unsupported(); + } + + @Override + public void setAuthenticated(boolean isAuthenticated) + throws IllegalArgumentException { + throw unsupported(); + } + + @Override + public String getName() { + return this.username; + } + + private UnsupportedOperationException unsupported() { + return new UnsupportedOperationException("Not Supported"); + } + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunctionTests.java index f313e5d42d..8a5929f855 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -17,19 +17,51 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.core.codec.ByteBufferEncoder; +import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.ServerSentEventHttpMessageWriter; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.codec.xml.Jaxb2XmlEncoder; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.client.reactive.MockClientHttpRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; +import reactor.core.publisher.Mono; import java.net.URI; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.springframework.http.HttpMethod.GET; import static org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; @@ -37,7 +69,11 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c * @author Rob Winch * @since 5.1 */ +@RunWith(MockitoJUnitRunner.class) public class OAuth2AuthorizedClientExchangeFilterFunctionTests { + @Mock + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientExchangeFilterFunction function = new OAuth2AuthorizedClientExchangeFilterFunction(); private MockExchangeFunction exchange = new MockExchangeFunction(); @@ -57,7 +93,7 @@ public class OAuth2AuthorizedClientExchangeFilterFunctionTests { .build(); private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", + "token-0", Instant.now(), Instant.now().plus(Duration.ofDays(1))); @@ -98,4 +134,180 @@ public class OAuth2AuthorizedClientExchangeFilterFunctionTests { HttpHeaders headers = this.exchange.getRequest().headers(); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); } + + @Test + public void filterWhenRefreshRequiredThenRefresh() { + when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .refreshToken("refresh-1") + .build(); + when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + Instant refreshTokenExpiresAt = Instant.now().plus(Duration.ofHours(1)); + + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + this.function = new OAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("test","this"); + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + + verify(this.authorizedClientService).saveAuthorizedClient(any(), eq(authentication)); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(2); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://github.com/login/oauth/access_token"); + assertThat(request0.method()).isEqualTo(HttpMethod.POST); + assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); + + ClientRequest request1 = requests.get(1); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { + when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .refreshToken("refresh-1") + .build(); + when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + Instant refreshTokenExpiresAt = Instant.now().plus(Duration.ofHours(1)); + + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + this.function = new OAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange) + .block(); + + verify(this.authorizedClientService).saveAuthorizedClient(any(), any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(2); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://github.com/login/oauth/access_token"); + assertThat(request0.method()).isEqualTo(HttpMethod.POST); + assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); + + ClientRequest request1 = requests.get(1); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { + this.function = new OAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange).block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); + } + + @Test + public void filterWhenNotExpiredThenShouldRefreshFalse() { + this.function = new OAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange).block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); + } + + private static String getBody(ClientRequest request) { + final List> messageWriters = new ArrayList<>(); + messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder())); + messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.textPlainOnly())); + messageWriters.add(new ResourceHttpMessageWriter()); + messageWriters.add(new EncoderHttpMessageWriter<>(new Jaxb2XmlEncoder())); + Jackson2JsonEncoder jsonEncoder = new Jackson2JsonEncoder(); + messageWriters.add(new EncoderHttpMessageWriter<>(jsonEncoder)); + messageWriters.add(new ServerSentEventHttpMessageWriter(jsonEncoder)); + messageWriters.add(new FormHttpMessageWriter()); + messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.allMimeTypes())); + messageWriters.add(new MultipartHttpMessageWriter(messageWriters)); + + BodyInserter.Context context = new BodyInserter.Context() { + @Override + public List> messageWriters() { + return messageWriters; + } + + @Override + public Optional serverRequest() { + return Optional.empty(); + } + + @Override + public Map hints() { + return new HashMap<>(); + } + }; + + MockClientHttpRequest body = new MockClientHttpRequest(HttpMethod.GET, "/"); + request.body().insert(body, context).block(); + return body.getBodyAsString().block(); + } }