Add exception handler
This commit is contained in:
parent
5dedba216f
commit
8a0ed3b96b
|
@ -1,22 +0,0 @@
|
|||
package org.baeldung.config;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
|
||||
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
|
||||
import org.springframework.security.oauth2.client.token.RequestEnhancer;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
|
||||
public class CustomRequestEnhancer implements RequestEnhancer, Serializable {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
@Override
|
||||
public void enhance(AccessTokenRequest request, OAuth2ProtectedResourceDetails resource, MultiValueMap<String, String> form, HttpHeaders headers) {
|
||||
System.out.println("called===");
|
||||
form.set("duration", "permanent");
|
||||
}
|
||||
}
|
|
@ -1,17 +1,11 @@
|
|||
package org.baeldung.config;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.io.Serializable;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.http.client.ClientHttpResponse;
|
||||
import org.springframework.security.access.AccessDeniedException;
|
||||
import org.springframework.security.oauth2.client.filter.state.DefaultStateKeyGenerator;
|
||||
import org.springframework.security.oauth2.client.filter.state.StateKeyGenerator;
|
||||
|
@ -20,86 +14,22 @@ import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResour
|
|||
import org.springframework.security.oauth2.client.resource.UserApprovalRequiredException;
|
||||
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
|
||||
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
|
||||
import org.springframework.security.oauth2.client.token.DefaultRequestEnhancer;
|
||||
import org.springframework.security.oauth2.client.token.RequestEnhancer;
|
||||
import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider;
|
||||
import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeResourceDetails;
|
||||
import org.springframework.security.oauth2.common.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
|
||||
import org.springframework.security.oauth2.common.util.OAuth2Utils;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.ResponseExtractor;
|
||||
|
||||
import com.google.common.base.Joiner;
|
||||
public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAccessTokenProvider implements Serializable {
|
||||
|
||||
public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAccessTokenProvider {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = 3822611002661972274L;
|
||||
|
||||
private StateKeyGenerator stateKeyGenerator = new DefaultStateKeyGenerator();
|
||||
|
||||
private String scopePrefix = OAuth2Utils.SCOPE_PREFIX;
|
||||
|
||||
private RequestEnhancer authorizationRequestEnhancer = new DefaultRequestEnhancer();
|
||||
|
||||
@Override
|
||||
public String obtainAuthorizationCode(OAuth2ProtectedResourceDetails details, AccessTokenRequest request) throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException, OAuth2AccessDeniedException {
|
||||
AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details;
|
||||
|
||||
HttpHeaders headers = getHeadersForAuthorizationRequest(request);
|
||||
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
|
||||
if (request.containsKey(OAuth2Utils.USER_OAUTH_APPROVAL)) {
|
||||
form.set(OAuth2Utils.USER_OAUTH_APPROVAL, request.getFirst(OAuth2Utils.USER_OAUTH_APPROVAL));
|
||||
for (String scope : details.getScope()) {
|
||||
form.set(scopePrefix + scope, request.getFirst(OAuth2Utils.USER_OAUTH_APPROVAL));
|
||||
}
|
||||
} else {
|
||||
form.putAll(getParametersForAuthorizeRequest(resource, request));
|
||||
}
|
||||
form.set("duration", "permanent");
|
||||
System.out.println("===== at enhancer point ===");
|
||||
authorizationRequestEnhancer.enhance(request, resource, form, headers);
|
||||
final AccessTokenRequest copy = request;
|
||||
|
||||
final ResponseExtractor<ResponseEntity<Void>> delegate = getAuthorizationResponseExtractor();
|
||||
ResponseExtractor<ResponseEntity<Void>> extractor = new ResponseExtractor<ResponseEntity<Void>>() {
|
||||
@Override
|
||||
public ResponseEntity<Void> extractData(ClientHttpResponse response) throws IOException {
|
||||
if (response.getHeaders().containsKey("Set-Cookie")) {
|
||||
copy.setCookie(response.getHeaders().getFirst("Set-Cookie"));
|
||||
}
|
||||
return delegate.extractData(response);
|
||||
}
|
||||
};
|
||||
|
||||
ResponseEntity<Void> response = getRestTemplate().execute(resource.getUserAuthorizationUri(), HttpMethod.POST, getRequestCallback(resource, form, headers), extractor, form.toSingleValueMap());
|
||||
|
||||
if (response.getStatusCode() == HttpStatus.OK) {
|
||||
throw getUserApprovalSignal(resource, request);
|
||||
}
|
||||
|
||||
URI location = response.getHeaders().getLocation();
|
||||
String query = location.getQuery();
|
||||
Map<String, String> map = OAuth2Utils.extractMap(query);
|
||||
if (map.containsKey("state")) {
|
||||
request.setStateKey(map.get("state"));
|
||||
if (request.getPreservedState() == null) {
|
||||
String redirectUri = resource.getRedirectUri(request);
|
||||
if (redirectUri != null) {
|
||||
request.setPreservedState(redirectUri);
|
||||
} else {
|
||||
request.setPreservedState(new Object());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
String code = map.get("code");
|
||||
if (code == null) {
|
||||
throw new UserRedirectRequiredException(location.toString(), form.toSingleValueMap());
|
||||
}
|
||||
request.set("code", code);
|
||||
return code;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details, AccessTokenRequest request) throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException, OAuth2AccessDeniedException {
|
||||
AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details;
|
||||
|
@ -118,15 +48,6 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc
|
|||
return headers;
|
||||
}
|
||||
|
||||
private HttpHeaders getHeadersForAuthorizationRequest(AccessTokenRequest request) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.putAll(request.getHeaders());
|
||||
if (request.getCookie() != null) {
|
||||
headers.set("Cookie", request.getCookie());
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
private MultiValueMap<String, String> getParametersForTokenRequest(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) {
|
||||
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
|
||||
form.set("grant_type", "authorization_code");
|
||||
|
@ -153,41 +74,6 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc
|
|||
return form;
|
||||
}
|
||||
|
||||
private MultiValueMap<String, String> getParametersForAuthorizeRequest(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) {
|
||||
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
|
||||
form.set("response_type", "code");
|
||||
form.set("client_id", resource.getClientId());
|
||||
|
||||
if (request.get("scope") != null) {
|
||||
form.set("scope", request.getFirst("scope"));
|
||||
} else {
|
||||
form.set("scope", Joiner.on(',').join(resource.getScope()));
|
||||
}
|
||||
|
||||
String redirectUri = resource.getPreEstablishedRedirectUri();
|
||||
|
||||
Object preservedState = request.getPreservedState();
|
||||
if (redirectUri == null && preservedState != null) {
|
||||
redirectUri = String.valueOf(preservedState);
|
||||
} else {
|
||||
redirectUri = request.getCurrentUri();
|
||||
}
|
||||
|
||||
String stateKey = request.getStateKey();
|
||||
if (stateKey != null) {
|
||||
form.set("state", stateKey);
|
||||
if (preservedState == null) {
|
||||
throw new InvalidRequestException("Possible CSRF detected - state parameter was present but no state could be found");
|
||||
}
|
||||
}
|
||||
|
||||
if (redirectUri != null) {
|
||||
form.set("redirect_uri", redirectUri);
|
||||
}
|
||||
|
||||
return form;
|
||||
}
|
||||
|
||||
private UserRedirectRequiredException getRedirectForAuthorization(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) {
|
||||
TreeMap<String, String> requestParameters = new TreeMap<String, String>();
|
||||
requestParameters.put("response_type", "code");
|
||||
|
|
|
@ -3,6 +3,7 @@ package org.baeldung.config;
|
|||
import java.util.Arrays;
|
||||
|
||||
import org.baeldung.web.RedditController;
|
||||
import org.baeldung.web.RestExceptionHandler;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
@ -15,7 +16,6 @@ import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResour
|
|||
import org.springframework.security.oauth2.client.token.AccessTokenProvider;
|
||||
import org.springframework.security.oauth2.client.token.AccessTokenProviderChain;
|
||||
import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsAccessTokenProvider;
|
||||
import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider;
|
||||
import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeResourceDetails;
|
||||
import org.springframework.security.oauth2.client.token.grant.implicit.ImplicitAccessTokenProvider;
|
||||
import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordAccessTokenProvider;
|
||||
|
@ -56,6 +56,11 @@ public class WebConfig extends WebMvcConfigurerAdapter {
|
|||
return controller;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public RestExceptionHandler restExceptionHandler() {
|
||||
return new RestExceptionHandler();
|
||||
}
|
||||
|
||||
public void addResourceHandlers(ResourceHandlerRegistry registry) {
|
||||
registry.addResourceHandler("/resources/**").addResourceLocations("/resources/");
|
||||
}
|
||||
|
@ -94,9 +99,8 @@ public class WebConfig extends WebMvcConfigurerAdapter {
|
|||
@Bean
|
||||
public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext) {
|
||||
OAuth2RestTemplate template = new OAuth2RestTemplate(reddit(), clientContext);
|
||||
AuthorizationCodeAccessTokenProvider authProvider = new AuthorizationCodeAccessTokenProvider();
|
||||
authProvider.setAuthorizationRequestEnhancer(new CustomRequestEnhancer());
|
||||
AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(authProvider, new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(), new ClientCredentialsAccessTokenProvider()));
|
||||
AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(),
|
||||
new ClientCredentialsAccessTokenProvider()));
|
||||
template.setAccessTokenProvider(accessTokenProvider);
|
||||
return template;
|
||||
}
|
||||
|
|
|
@ -8,13 +8,7 @@ import java.util.Map;
|
|||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
|
||||
import org.springframework.security.oauth2.client.resource.UserApprovalRequiredException;
|
||||
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.ui.Model;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
|
@ -34,75 +28,46 @@ public class RedditController {
|
|||
|
||||
@RequestMapping("/info")
|
||||
public String getInfo(Model model) {
|
||||
try {
|
||||
String result = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", String.class);
|
||||
JsonNode node = new ObjectMapper().readTree(result);
|
||||
String name = node.get("name").asText();
|
||||
model.addAttribute("info", name);
|
||||
} catch (UserApprovalRequiredException e) {
|
||||
throw e;
|
||||
} catch (UserRedirectRequiredException e) {
|
||||
throw e;
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error occurred", e);
|
||||
model.addAttribute("error", e.getMessage());
|
||||
}
|
||||
JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class);
|
||||
String name = node.get("name").asText();
|
||||
model.addAttribute("info", name);
|
||||
return "reddit";
|
||||
}
|
||||
|
||||
@RequestMapping("/submit")
|
||||
public String submit(Model model, @RequestParam Map<String, String> formParams) {
|
||||
try {
|
||||
MultiValueMap<String, String> param = new LinkedMultiValueMap<String, String>();
|
||||
param.add("api_type", "json");
|
||||
param.add("kind", "link");
|
||||
param.add("resubmit", "true");
|
||||
param.add("sendreplies", "false");
|
||||
param.add("then", "comments");
|
||||
MultiValueMap<String, String> param = new LinkedMultiValueMap<String, String>();
|
||||
param.add("api_type", "json");
|
||||
param.add("kind", "link");
|
||||
param.add("resubmit", "true");
|
||||
param.add("sendreplies", "false");
|
||||
param.add("then", "comments");
|
||||
|
||||
for (Map.Entry<String, String> entry : formParams.entrySet()) {
|
||||
param.add(entry.getKey(), entry.getValue());
|
||||
}
|
||||
|
||||
LOGGER.info("User submitting Link with these parameters: " + formParams.entrySet());
|
||||
String result = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, String.class);
|
||||
LOGGER.info("Full Reddit Response: " + result);
|
||||
String responseMsg = parseResponse(result);
|
||||
model.addAttribute("msg", responseMsg);
|
||||
} catch (UserApprovalRequiredException e) {
|
||||
throw e;
|
||||
} catch (UserRedirectRequiredException e) {
|
||||
throw e;
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error occurred", e);
|
||||
model.addAttribute("msg", e.getLocalizedMessage());
|
||||
for (Map.Entry<String, String> entry : formParams.entrySet()) {
|
||||
param.add(entry.getKey(), entry.getValue());
|
||||
}
|
||||
|
||||
LOGGER.info("User submitting Link with these parameters: " + formParams.entrySet());
|
||||
JsonNode node = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, JsonNode.class);
|
||||
LOGGER.info("Full Reddit Response: " + node.toString());
|
||||
String responseMsg = parseResponse(node);
|
||||
model.addAttribute("msg", responseMsg);
|
||||
return "submissionResponse";
|
||||
}
|
||||
|
||||
@RequestMapping("/post")
|
||||
public String showSubmissionForm(Model model) throws JsonProcessingException, IOException {
|
||||
try {
|
||||
String needsCaptchaResult = needsCaptcha();
|
||||
if (needsCaptchaResult.equalsIgnoreCase("true")) {
|
||||
String iden = getNewCaptcha();
|
||||
model.addAttribute("iden", iden);
|
||||
}
|
||||
} catch (UserApprovalRequiredException e) {
|
||||
throw e;
|
||||
} catch (UserRedirectRequiredException e) {
|
||||
throw e;
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error occurred", e);
|
||||
model.addAttribute("error", e.getLocalizedMessage());
|
||||
return "reddit";
|
||||
public String showSubmissionForm(Model model) {
|
||||
String needsCaptchaResult = needsCaptcha();
|
||||
if (needsCaptchaResult.equalsIgnoreCase("true")) {
|
||||
String iden = getNewCaptcha();
|
||||
model.addAttribute("iden", iden);
|
||||
}
|
||||
return "submissionForm";
|
||||
}
|
||||
|
||||
// === private
|
||||
|
||||
private List<String> getSubreddit() throws JsonProcessingException, IOException {
|
||||
public List<String> getSubreddit() throws JsonProcessingException, IOException {
|
||||
String result = redditRestTemplate.getForObject("https://oauth.reddit.com/subreddits/popular?limit=50", String.class);
|
||||
JsonNode node = new ObjectMapper().readTree(result);
|
||||
node = node.get("data").get("children");
|
||||
|
@ -119,21 +84,16 @@ public class RedditController {
|
|||
}
|
||||
|
||||
private String getNewCaptcha() {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity req = new HttpEntity(headers);
|
||||
|
||||
Map<String, String> param = new HashMap<String, String>();
|
||||
param.put("api_type", "json");
|
||||
|
||||
ResponseEntity<String> result = redditRestTemplate.postForEntity("https://oauth.reddit.com/api/new_captcha", req, String.class, param);
|
||||
String[] split = result.getBody().split("\"");
|
||||
String result = redditRestTemplate.postForObject("https://oauth.reddit.com/api/new_captcha", param, String.class, param);
|
||||
String[] split = result.split("\"");
|
||||
return split[split.length - 2];
|
||||
}
|
||||
|
||||
private String parseResponse(String responseBody) throws JsonProcessingException, IOException {
|
||||
private String parseResponse(JsonNode node) {
|
||||
String result = "";
|
||||
JsonNode node = new ObjectMapper().readTree(responseBody);
|
||||
JsonNode errorNode = node.get("json").get("errors").get(0);
|
||||
if (errorNode != null) {
|
||||
for (JsonNode child : errorNode) {
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
package org.baeldung.web;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.oauth2.client.resource.UserApprovalRequiredException;
|
||||
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
|
||||
import org.springframework.web.bind.annotation.ControllerAdvice;
|
||||
import org.springframework.web.bind.annotation.ExceptionHandler;
|
||||
import org.springframework.web.context.request.WebRequest;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseEntityExceptionHandler;
|
||||
|
||||
@ControllerAdvice
|
||||
public class RestExceptionHandler extends ResponseEntityExceptionHandler implements Serializable {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private static final long serialVersionUID = -3861125729653781371L;
|
||||
|
||||
public RestExceptionHandler() {
|
||||
super();
|
||||
}
|
||||
|
||||
// API
|
||||
|
||||
// 500
|
||||
@ExceptionHandler({ UserApprovalRequiredException.class, UserRedirectRequiredException.class })
|
||||
public ResponseEntity<Object> handleRedirect(final RuntimeException ex, final WebRequest request) {
|
||||
logger.error("500 Status Code", ex);
|
||||
throw ex;
|
||||
}
|
||||
|
||||
@ExceptionHandler({ Exception.class })
|
||||
public ResponseEntity<Object> handleInternal(final RuntimeException ex, final WebRequest request) {
|
||||
logger.info(request.getHeader("x-ratelimit-remaining"));
|
||||
logger.error("500 Status Code", ex);
|
||||
String response = "Error Occurred : " + ex.getMessage();
|
||||
return handleExceptionInternal(ex, response, new HttpHeaders(), HttpStatus.INTERNAL_SERVER_ERROR, request);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue