From 0d6007309574c5e29c255fc910aedf9dfc8934a4 Mon Sep 17 00:00:00 2001 From: DOHA Date: Sat, 24 Nov 2018 13:36:14 +0200 Subject: [PATCH] customise oauth2 requests --- .../CustomAuthorizationRequestResolver.java | 50 ++++++++++++++ .../oauth2/CustomRequestEntityConverter.java | 26 +++++++ .../oauth2/CustomTokenResponseConverter.java | 67 +++++++++++++++++++ .../com/baeldung/oauth2/SecurityConfig.java | 18 ++++- 4 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 spring-5-security/src/main/java/com/baeldung/oauth2/CustomAuthorizationRequestResolver.java create mode 100644 spring-5-security/src/main/java/com/baeldung/oauth2/CustomRequestEntityConverter.java create mode 100644 spring-5-security/src/main/java/com/baeldung/oauth2/CustomTokenResponseConverter.java diff --git a/spring-5-security/src/main/java/com/baeldung/oauth2/CustomAuthorizationRequestResolver.java b/spring-5-security/src/main/java/com/baeldung/oauth2/CustomAuthorizationRequestResolver.java new file mode 100644 index 0000000000..025064423d --- /dev/null +++ b/spring-5-security/src/main/java/com/baeldung/oauth2/CustomAuthorizationRequestResolver.java @@ -0,0 +1,50 @@ +package com.baeldung.oauth2; + +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; + +public class CustomAuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver { + + private OAuth2AuthorizationRequestResolver defaultResolver; + + public CustomAuthorizationRequestResolver(ClientRegistrationRepository repo, String authorizationRequestBaseUri){ + defaultResolver = new DefaultOAuth2AuthorizationRequestResolver(repo, authorizationRequestBaseUri); + } + + @Override + public OAuth2AuthorizationRequest resolve(HttpServletRequest request) { + OAuth2AuthorizationRequest req = defaultResolver.resolve(request); + if(req != null){ + req = customizeAuthorizationRequest(req); + } + return req; + } + + @Override + public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId) { + OAuth2AuthorizationRequest req = defaultResolver.resolve(request, clientRegistrationId); + if(req != null){ + req = customizeAuthorizationRequest(req); + } + return req; + } + + private OAuth2AuthorizationRequest customizeAuthorizationRequest(OAuth2AuthorizationRequest req) { + Map extraParams = new HashMap(); + extraParams.putAll(req.getAdditionalParameters()); //VIP note + extraParams.put("test", "extra"); + System.out.println("here ====================="); + return OAuth2AuthorizationRequest.from(req).additionalParameters(extraParams).build(); + } + + private OAuth2AuthorizationRequest customizeAuthorizationRequest1(OAuth2AuthorizationRequest req) { + return OAuth2AuthorizationRequest.from(req).state("xyz").build(); + } +} diff --git a/spring-5-security/src/main/java/com/baeldung/oauth2/CustomRequestEntityConverter.java b/spring-5-security/src/main/java/com/baeldung/oauth2/CustomRequestEntityConverter.java new file mode 100644 index 0000000000..8884065769 --- /dev/null +++ b/spring-5-security/src/main/java/com/baeldung/oauth2/CustomRequestEntityConverter.java @@ -0,0 +1,26 @@ +package com.baeldung.oauth2; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequestEntityConverter; +import org.springframework.util.MultiValueMap; + +public class CustomRequestEntityConverter implements Converter> { + + private OAuth2AuthorizationCodeGrantRequestEntityConverter defaultConverter; + + public CustomRequestEntityConverter() { + defaultConverter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); + } + + @Override + public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest req) { + RequestEntity entity = defaultConverter.convert(req); + MultiValueMap params = (MultiValueMap) entity.getBody(); + params.add("test2", "extra2"); + System.out.println(params.entrySet()); + return new RequestEntity<>(params, entity.getHeaders(), entity.getMethod(), entity.getUrl()); + } + +} diff --git a/spring-5-security/src/main/java/com/baeldung/oauth2/CustomTokenResponseConverter.java b/spring-5-security/src/main/java/com/baeldung/oauth2/CustomTokenResponseConverter.java new file mode 100644 index 0000000000..741f44871a --- /dev/null +++ b/spring-5-security/src/main/java/com/baeldung/oauth2/CustomTokenResponseConverter.java @@ -0,0 +1,67 @@ +package com.baeldung.oauth2; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.StringUtils; + +public class CustomTokenResponseConverter implements Converter, OAuth2AccessTokenResponse> { + private static final Set TOKEN_RESPONSE_PARAMETER_NAMES = Stream.of( + OAuth2ParameterNames.ACCESS_TOKEN, + OAuth2ParameterNames.TOKEN_TYPE, + OAuth2ParameterNames.EXPIRES_IN, + OAuth2ParameterNames.REFRESH_TOKEN, + OAuth2ParameterNames.SCOPE) .collect(Collectors.toSet()); + + @Override + public OAuth2AccessTokenResponse convert(Map tokenResponseParameters) { + String accessToken = tokenResponseParameters.get(OAuth2ParameterNames.ACCESS_TOKEN); + + OAuth2AccessToken.TokenType accessTokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue() + .equalsIgnoreCase(tokenResponseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) { + accessTokenType = OAuth2AccessToken.TokenType.BEARER; + } + + long expiresIn = 0; + if (tokenResponseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) { + try { + expiresIn = Long.valueOf(tokenResponseParameters.get(OAuth2ParameterNames.EXPIRES_IN)); + } catch (NumberFormatException ex) { + } + } + + Set scopes = Collections.emptySet(); + if (tokenResponseParameters.containsKey(OAuth2ParameterNames.SCOPE)) { + String scope = tokenResponseParameters.get(OAuth2ParameterNames.SCOPE); + scopes = Arrays.stream(StringUtils.delimitedListToStringArray(scope, " ")) + .collect(Collectors.toSet()); + } + + String refreshToken = tokenResponseParameters.get(OAuth2ParameterNames.REFRESH_TOKEN); + + Map additionalParameters = new LinkedHashMap<>(); + tokenResponseParameters.entrySet() + .stream() + .filter(e -> !TOKEN_RESPONSE_PARAMETER_NAMES.contains(e.getKey())) + .forEach(e -> additionalParameters.put(e.getKey(), e.getValue())); + + return OAuth2AccessTokenResponse.withToken(accessToken) + .tokenType(accessTokenType) + .expiresIn(expiresIn) + .scopes(scopes) + .refreshToken(refreshToken) + .additionalParameters(additionalParameters) + .build(); + } + +} diff --git a/spring-5-security/src/main/java/com/baeldung/oauth2/SecurityConfig.java b/spring-5-security/src/main/java/com/baeldung/oauth2/SecurityConfig.java index b45f325767..cf27b01a75 100644 --- a/spring-5-security/src/main/java/com/baeldung/oauth2/SecurityConfig.java +++ b/spring-5-security/src/main/java/com/baeldung/oauth2/SecurityConfig.java @@ -9,18 +9,22 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.PropertySource; import org.springframework.core.env.Environment; +import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; -import org.springframework.security.oauth2.client.endpoint.NimbusAuthorizationCodeTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.web.client.RestTemplate; @Configuration @PropertySource("application-oauth2.properties") @@ -37,6 +41,8 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { .oauth2Login() .loginPage("/oauth_login") .authorizationEndpoint() + .authorizationRequestResolver( new CustomAuthorizationRequestResolver(clientRegistrationRepository(),"/oauth2/authorize-client")) + .baseUri("/oauth2/authorize-client") .authorizationRequestRepository(authorizationRequestRepository()) .and() @@ -54,7 +60,15 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Bean public OAuth2AccessTokenResponseClient accessTokenResponseClient() { - return new NimbusAuthorizationCodeTokenResponseClient(); + DefaultAuthorizationCodeTokenResponseClient accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); + accessTokenResponseClient.setRequestEntityConverter(new CustomRequestEntityConverter()); + + OAuth2AccessTokenResponseHttpMessageConverter tokenResponseHttpMessageConverter = new OAuth2AccessTokenResponseHttpMessageConverter(); + tokenResponseHttpMessageConverter.setTokenResponseConverter(new CustomTokenResponseConverter()); + RestTemplate restTemplate = new RestTemplate(Arrays.asList(new FormHttpMessageConverter(), tokenResponseHttpMessageConverter)); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + accessTokenResponseClient.setRestOperations(restTemplate); + return accessTokenResponseClient; }