From 8a0ed3b96bc8197f0dde4b21b33dfd0e87bde261 Mon Sep 17 00:00:00 2001 From: DOHA Date: Thu, 26 Feb 2015 15:56:16 +0200 Subject: [PATCH] Add exception handler --- .../config/CustomRequestEnhancer.java | 22 --- ...yAuthorizationCodeAccessTokenProvider.java | 126 +----------------- .../java/org/baeldung/config/WebConfig.java | 12 +- .../org/baeldung/web/RedditController.java | 92 ++++--------- .../baeldung/web/RestExceptionHandler.java | 42 ++++++ 5 files changed, 82 insertions(+), 212 deletions(-) delete mode 100644 spring-security-oauth/src/main/java/org/baeldung/config/CustomRequestEnhancer.java create mode 100644 spring-security-oauth/src/main/java/org/baeldung/web/RestExceptionHandler.java diff --git a/spring-security-oauth/src/main/java/org/baeldung/config/CustomRequestEnhancer.java b/spring-security-oauth/src/main/java/org/baeldung/config/CustomRequestEnhancer.java deleted file mode 100644 index cb213719af..0000000000 --- a/spring-security-oauth/src/main/java/org/baeldung/config/CustomRequestEnhancer.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.baeldung.config; - -import java.io.Serializable; - -import org.springframework.http.HttpHeaders; -import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails; -import org.springframework.security.oauth2.client.token.AccessTokenRequest; -import org.springframework.security.oauth2.client.token.RequestEnhancer; -import org.springframework.util.MultiValueMap; - -public class CustomRequestEnhancer implements RequestEnhancer, Serializable { - /** - * - */ - private static final long serialVersionUID = 1L; - - @Override - public void enhance(AccessTokenRequest request, OAuth2ProtectedResourceDetails resource, MultiValueMap form, HttpHeaders headers) { - System.out.println("called==="); - form.set("duration", "permanent"); - } -} diff --git a/spring-security-oauth/src/main/java/org/baeldung/config/MyAuthorizationCodeAccessTokenProvider.java b/spring-security-oauth/src/main/java/org/baeldung/config/MyAuthorizationCodeAccessTokenProvider.java index 47ff588fc0..ca37ab3f82 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/config/MyAuthorizationCodeAccessTokenProvider.java +++ b/spring-security-oauth/src/main/java/org/baeldung/config/MyAuthorizationCodeAccessTokenProvider.java @@ -1,17 +1,11 @@ package org.baeldung.config; -import java.io.IOException; -import java.net.URI; +import java.io.Serializable; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.TreeMap; import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; -import org.springframework.http.ResponseEntity; -import org.springframework.http.client.ClientHttpResponse; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.oauth2.client.filter.state.DefaultStateKeyGenerator; import org.springframework.security.oauth2.client.filter.state.StateKeyGenerator; @@ -20,86 +14,22 @@ import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResour import org.springframework.security.oauth2.client.resource.UserApprovalRequiredException; import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException; import org.springframework.security.oauth2.client.token.AccessTokenRequest; -import org.springframework.security.oauth2.client.token.DefaultRequestEnhancer; -import org.springframework.security.oauth2.client.token.RequestEnhancer; import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider; import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeResourceDetails; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.exceptions.InvalidRequestException; -import org.springframework.security.oauth2.common.util.OAuth2Utils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.web.client.ResponseExtractor; -import com.google.common.base.Joiner; +public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAccessTokenProvider implements Serializable { -public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAccessTokenProvider { + /** + * + */ + private static final long serialVersionUID = 3822611002661972274L; private StateKeyGenerator stateKeyGenerator = new DefaultStateKeyGenerator(); - private String scopePrefix = OAuth2Utils.SCOPE_PREFIX; - - private RequestEnhancer authorizationRequestEnhancer = new DefaultRequestEnhancer(); - - @Override - public String obtainAuthorizationCode(OAuth2ProtectedResourceDetails details, AccessTokenRequest request) throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException, OAuth2AccessDeniedException { - AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details; - - HttpHeaders headers = getHeadersForAuthorizationRequest(request); - MultiValueMap form = new LinkedMultiValueMap(); - if (request.containsKey(OAuth2Utils.USER_OAUTH_APPROVAL)) { - form.set(OAuth2Utils.USER_OAUTH_APPROVAL, request.getFirst(OAuth2Utils.USER_OAUTH_APPROVAL)); - for (String scope : details.getScope()) { - form.set(scopePrefix + scope, request.getFirst(OAuth2Utils.USER_OAUTH_APPROVAL)); - } - } else { - form.putAll(getParametersForAuthorizeRequest(resource, request)); - } - form.set("duration", "permanent"); - System.out.println("===== at enhancer point ==="); - authorizationRequestEnhancer.enhance(request, resource, form, headers); - final AccessTokenRequest copy = request; - - final ResponseExtractor> delegate = getAuthorizationResponseExtractor(); - ResponseExtractor> extractor = new ResponseExtractor>() { - @Override - public ResponseEntity extractData(ClientHttpResponse response) throws IOException { - if (response.getHeaders().containsKey("Set-Cookie")) { - copy.setCookie(response.getHeaders().getFirst("Set-Cookie")); - } - return delegate.extractData(response); - } - }; - - ResponseEntity response = getRestTemplate().execute(resource.getUserAuthorizationUri(), HttpMethod.POST, getRequestCallback(resource, form, headers), extractor, form.toSingleValueMap()); - - if (response.getStatusCode() == HttpStatus.OK) { - throw getUserApprovalSignal(resource, request); - } - - URI location = response.getHeaders().getLocation(); - String query = location.getQuery(); - Map map = OAuth2Utils.extractMap(query); - if (map.containsKey("state")) { - request.setStateKey(map.get("state")); - if (request.getPreservedState() == null) { - String redirectUri = resource.getRedirectUri(request); - if (redirectUri != null) { - request.setPreservedState(redirectUri); - } else { - request.setPreservedState(new Object()); - } - } - } - - String code = map.get("code"); - if (code == null) { - throw new UserRedirectRequiredException(location.toString(), form.toSingleValueMap()); - } - request.set("code", code); - return code; - } - @Override public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details, AccessTokenRequest request) throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException, OAuth2AccessDeniedException { AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details; @@ -118,15 +48,6 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc return headers; } - private HttpHeaders getHeadersForAuthorizationRequest(AccessTokenRequest request) { - HttpHeaders headers = new HttpHeaders(); - headers.putAll(request.getHeaders()); - if (request.getCookie() != null) { - headers.set("Cookie", request.getCookie()); - } - return headers; - } - private MultiValueMap getParametersForTokenRequest(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) { MultiValueMap form = new LinkedMultiValueMap(); form.set("grant_type", "authorization_code"); @@ -153,41 +74,6 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc return form; } - private MultiValueMap getParametersForAuthorizeRequest(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) { - MultiValueMap form = new LinkedMultiValueMap(); - form.set("response_type", "code"); - form.set("client_id", resource.getClientId()); - - if (request.get("scope") != null) { - form.set("scope", request.getFirst("scope")); - } else { - form.set("scope", Joiner.on(',').join(resource.getScope())); - } - - String redirectUri = resource.getPreEstablishedRedirectUri(); - - Object preservedState = request.getPreservedState(); - if (redirectUri == null && preservedState != null) { - redirectUri = String.valueOf(preservedState); - } else { - redirectUri = request.getCurrentUri(); - } - - String stateKey = request.getStateKey(); - if (stateKey != null) { - form.set("state", stateKey); - if (preservedState == null) { - throw new InvalidRequestException("Possible CSRF detected - state parameter was present but no state could be found"); - } - } - - if (redirectUri != null) { - form.set("redirect_uri", redirectUri); - } - - return form; - } - private UserRedirectRequiredException getRedirectForAuthorization(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) { TreeMap requestParameters = new TreeMap(); requestParameters.put("response_type", "code"); diff --git a/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java b/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java index f7167eaee6..bd7a11820b 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java +++ b/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java @@ -3,6 +3,7 @@ package org.baeldung.config; import java.util.Arrays; import org.baeldung.web.RedditController; +import org.baeldung.web.RestExceptionHandler; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; @@ -15,7 +16,6 @@ import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResour import org.springframework.security.oauth2.client.token.AccessTokenProvider; import org.springframework.security.oauth2.client.token.AccessTokenProviderChain; import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsAccessTokenProvider; -import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider; import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeResourceDetails; import org.springframework.security.oauth2.client.token.grant.implicit.ImplicitAccessTokenProvider; import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordAccessTokenProvider; @@ -56,6 +56,11 @@ public class WebConfig extends WebMvcConfigurerAdapter { return controller; } + @Bean + public RestExceptionHandler restExceptionHandler() { + return new RestExceptionHandler(); + } + public void addResourceHandlers(ResourceHandlerRegistry registry) { registry.addResourceHandler("/resources/**").addResourceLocations("/resources/"); } @@ -94,9 +99,8 @@ public class WebConfig extends WebMvcConfigurerAdapter { @Bean public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext) { OAuth2RestTemplate template = new OAuth2RestTemplate(reddit(), clientContext); - AuthorizationCodeAccessTokenProvider authProvider = new AuthorizationCodeAccessTokenProvider(); - authProvider.setAuthorizationRequestEnhancer(new CustomRequestEnhancer()); - AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays. asList(authProvider, new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(), new ClientCredentialsAccessTokenProvider())); + AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays. asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(), + new ClientCredentialsAccessTokenProvider())); template.setAccessTokenProvider(accessTokenProvider); return template; } diff --git a/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java b/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java index ef80ff1747..869553448b 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java +++ b/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java @@ -8,13 +8,7 @@ import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; -import org.springframework.http.MediaType; -import org.springframework.http.ResponseEntity; import org.springframework.security.oauth2.client.OAuth2RestTemplate; -import org.springframework.security.oauth2.client.resource.UserApprovalRequiredException; -import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException; import org.springframework.stereotype.Controller; import org.springframework.ui.Model; import org.springframework.util.LinkedMultiValueMap; @@ -34,75 +28,46 @@ public class RedditController { @RequestMapping("/info") public String getInfo(Model model) { - try { - String result = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", String.class); - JsonNode node = new ObjectMapper().readTree(result); - String name = node.get("name").asText(); - model.addAttribute("info", name); - } catch (UserApprovalRequiredException e) { - throw e; - } catch (UserRedirectRequiredException e) { - throw e; - } catch (Exception e) { - LOGGER.error("Error occurred", e); - model.addAttribute("error", e.getMessage()); - } + JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class); + String name = node.get("name").asText(); + model.addAttribute("info", name); return "reddit"; } @RequestMapping("/submit") public String submit(Model model, @RequestParam Map formParams) { - try { - MultiValueMap param = new LinkedMultiValueMap(); - param.add("api_type", "json"); - param.add("kind", "link"); - param.add("resubmit", "true"); - param.add("sendreplies", "false"); - param.add("then", "comments"); + MultiValueMap param = new LinkedMultiValueMap(); + param.add("api_type", "json"); + param.add("kind", "link"); + param.add("resubmit", "true"); + param.add("sendreplies", "false"); + param.add("then", "comments"); - for (Map.Entry entry : formParams.entrySet()) { - param.add(entry.getKey(), entry.getValue()); - } - - LOGGER.info("User submitting Link with these parameters: " + formParams.entrySet()); - String result = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, String.class); - LOGGER.info("Full Reddit Response: " + result); - String responseMsg = parseResponse(result); - model.addAttribute("msg", responseMsg); - } catch (UserApprovalRequiredException e) { - throw e; - } catch (UserRedirectRequiredException e) { - throw e; - } catch (Exception e) { - LOGGER.error("Error occurred", e); - model.addAttribute("msg", e.getLocalizedMessage()); + for (Map.Entry entry : formParams.entrySet()) { + param.add(entry.getKey(), entry.getValue()); } + + LOGGER.info("User submitting Link with these parameters: " + formParams.entrySet()); + JsonNode node = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, JsonNode.class); + LOGGER.info("Full Reddit Response: " + node.toString()); + String responseMsg = parseResponse(node); + model.addAttribute("msg", responseMsg); return "submissionResponse"; } @RequestMapping("/post") - public String showSubmissionForm(Model model) throws JsonProcessingException, IOException { - try { - String needsCaptchaResult = needsCaptcha(); - if (needsCaptchaResult.equalsIgnoreCase("true")) { - String iden = getNewCaptcha(); - model.addAttribute("iden", iden); - } - } catch (UserApprovalRequiredException e) { - throw e; - } catch (UserRedirectRequiredException e) { - throw e; - } catch (Exception e) { - LOGGER.error("Error occurred", e); - model.addAttribute("error", e.getLocalizedMessage()); - return "reddit"; + public String showSubmissionForm(Model model) { + String needsCaptchaResult = needsCaptcha(); + if (needsCaptchaResult.equalsIgnoreCase("true")) { + String iden = getNewCaptcha(); + model.addAttribute("iden", iden); } return "submissionForm"; } // === private - private List getSubreddit() throws JsonProcessingException, IOException { + public List getSubreddit() throws JsonProcessingException, IOException { String result = redditRestTemplate.getForObject("https://oauth.reddit.com/subreddits/popular?limit=50", String.class); JsonNode node = new ObjectMapper().readTree(result); node = node.get("data").get("children"); @@ -119,21 +84,16 @@ public class RedditController { } private String getNewCaptcha() { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - HttpEntity req = new HttpEntity(headers); - Map param = new HashMap(); param.put("api_type", "json"); - ResponseEntity result = redditRestTemplate.postForEntity("https://oauth.reddit.com/api/new_captcha", req, String.class, param); - String[] split = result.getBody().split("\""); + String result = redditRestTemplate.postForObject("https://oauth.reddit.com/api/new_captcha", param, String.class, param); + String[] split = result.split("\""); return split[split.length - 2]; } - private String parseResponse(String responseBody) throws JsonProcessingException, IOException { + private String parseResponse(JsonNode node) { String result = ""; - JsonNode node = new ObjectMapper().readTree(responseBody); JsonNode errorNode = node.get("json").get("errors").get(0); if (errorNode != null) { for (JsonNode child : errorNode) { diff --git a/spring-security-oauth/src/main/java/org/baeldung/web/RestExceptionHandler.java b/spring-security-oauth/src/main/java/org/baeldung/web/RestExceptionHandler.java new file mode 100644 index 0000000000..0a1c671a0b --- /dev/null +++ b/spring-security-oauth/src/main/java/org/baeldung/web/RestExceptionHandler.java @@ -0,0 +1,42 @@ +package org.baeldung.web; + +import java.io.Serializable; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.client.resource.UserApprovalRequiredException; +import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.context.request.WebRequest; +import org.springframework.web.servlet.mvc.method.annotation.ResponseEntityExceptionHandler; + +@ControllerAdvice +public class RestExceptionHandler extends ResponseEntityExceptionHandler implements Serializable { + /** + * + */ + private static final long serialVersionUID = -3861125729653781371L; + + public RestExceptionHandler() { + super(); + } + + // API + + // 500 + @ExceptionHandler({ UserApprovalRequiredException.class, UserRedirectRequiredException.class }) + public ResponseEntity handleRedirect(final RuntimeException ex, final WebRequest request) { + logger.error("500 Status Code", ex); + throw ex; + } + + @ExceptionHandler({ Exception.class }) + public ResponseEntity handleInternal(final RuntimeException ex, final WebRequest request) { + logger.info(request.getHeader("x-ratelimit-remaining")); + logger.error("500 Status Code", ex); + String response = "Error Occurred : " + ex.getMessage(); + return handleExceptionInternal(ex, response, new HttpHeaders(), HttpStatus.INTERNAL_SERVER_ERROR, request); + } +} \ No newline at end of file