Merge pull request #155 from Doha2012/master

modify schedule reddit
This commit is contained in:
Eugen 2015-03-04 00:01:28 +02:00
commit 4c6e0c8a1f
12 changed files with 214 additions and 76 deletions

View File

@ -25,11 +25,11 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc
private static final long serialVersionUID = 3822611002661972274L; private static final long serialVersionUID = 3822611002661972274L;
private StateKeyGenerator stateKeyGenerator = new DefaultStateKeyGenerator(); private final StateKeyGenerator stateKeyGenerator = new DefaultStateKeyGenerator();
@Override @Override
public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details, AccessTokenRequest request) throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException, OAuth2AccessDeniedException { public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details, AccessTokenRequest request) throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException, OAuth2AccessDeniedException {
AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details; final AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details;
if (request.getAuthorizationCode() == null) { if (request.getAuthorizationCode() == null) {
if (request.getStateKey() == null) { if (request.getStateKey() == null) {
@ -41,16 +41,16 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc
} }
private HttpHeaders getHeadersForTokenRequest(AccessTokenRequest request) { private HttpHeaders getHeadersForTokenRequest(AccessTokenRequest request) {
HttpHeaders headers = new HttpHeaders(); final HttpHeaders headers = new HttpHeaders();
return headers; return headers;
} }
private MultiValueMap<String, String> getParametersForTokenRequest(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) { private MultiValueMap<String, String> getParametersForTokenRequest(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) {
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>(); final MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
form.set("grant_type", "authorization_code"); form.set("grant_type", "authorization_code");
form.set("code", request.getAuthorizationCode()); form.set("code", request.getAuthorizationCode());
Object preservedState = request.getPreservedState(); final Object preservedState = request.getPreservedState();
if (request.getStateKey() != null) { if (request.getStateKey() != null) {
if (preservedState == null) { if (preservedState == null) {
throw new InvalidRequestException("Possible CSRF detected - state parameter was present but no state could be found"); throw new InvalidRequestException("Possible CSRF detected - state parameter was present but no state could be found");
@ -64,7 +64,7 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc
redirectUri = resource.getRedirectUri(request); redirectUri = resource.getRedirectUri(request);
} }
if (redirectUri != null && !"NONE".equals(redirectUri)) { if ((redirectUri != null) && !"NONE".equals(redirectUri)) {
form.set("redirect_uri", redirectUri); form.set("redirect_uri", redirectUri);
} }
@ -72,24 +72,23 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc
} }
private UserRedirectRequiredException getRedirectForAuthorization(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) { private UserRedirectRequiredException getRedirectForAuthorization(AuthorizationCodeResourceDetails resource, AccessTokenRequest request) {
TreeMap<String, String> requestParameters = new TreeMap<String, String>(); final TreeMap<String, String> requestParameters = new TreeMap<String, String>();
requestParameters.put("response_type", "code"); requestParameters.put("response_type", "code");
requestParameters.put("client_id", resource.getClientId()); requestParameters.put("client_id", resource.getClientId());
requestParameters.put("duration", "permanent"); requestParameters.put("duration", "permanent");
System.out.println("===== at private message redirect ===");
String redirectUri = resource.getRedirectUri(request); final String redirectUri = resource.getRedirectUri(request);
if (redirectUri != null) { if (redirectUri != null) {
requestParameters.put("redirect_uri", redirectUri); requestParameters.put("redirect_uri", redirectUri);
} }
if (resource.isScoped()) { if (resource.isScoped()) {
StringBuilder builder = new StringBuilder(); final StringBuilder builder = new StringBuilder();
List<String> scope = resource.getScope(); final List<String> scope = resource.getScope();
if (scope != null) { if (scope != null) {
Iterator<String> scopeIt = scope.iterator(); final Iterator<String> scopeIt = scope.iterator();
while (scopeIt.hasNext()) { while (scopeIt.hasNext()) {
builder.append(scopeIt.next()); builder.append(scopeIt.next());
if (scopeIt.hasNext()) { if (scopeIt.hasNext()) {
@ -101,9 +100,9 @@ public class MyAuthorizationCodeAccessTokenProvider extends AuthorizationCodeAcc
requestParameters.put("scope", builder.toString()); requestParameters.put("scope", builder.toString());
} }
UserRedirectRequiredException redirectException = new UserRedirectRequiredException(resource.getUserAuthorizationUri(), requestParameters); final UserRedirectRequiredException redirectException = new UserRedirectRequiredException(resource.getUserAuthorizationUri(), requestParameters);
String stateKey = stateKeyGenerator.generateKey(resource); final String stateKey = stateKeyGenerator.generateKey(resource);
redirectException.setStateKey(stateKey); redirectException.setStateKey(stateKey);
request.setStateKey(stateKey); request.setStateKey(stateKey);
redirectException.setStateToPreserve(redirectUri); redirectException.setStateToPreserve(redirectUri);

View File

@ -3,16 +3,20 @@ package org.baeldung.config;
import javax.servlet.http.HttpSessionEvent; import javax.servlet.http.HttpSessionEvent;
import javax.servlet.http.HttpSessionListener; import javax.servlet.http.HttpSessionListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class SessionListener implements HttpSessionListener { public class SessionListener implements HttpSessionListener {
private final Logger logger = LoggerFactory.getLogger(getClass());
@Override @Override
public void sessionCreated(HttpSessionEvent event) { public void sessionCreated(HttpSessionEvent event) {
System.out.println("==== Session is created ===="); logger.info("==== Session is created ====");
event.getSession().setMaxInactiveInterval(30 * 60); event.getSession().setMaxInactiveInterval(30 * 60);
} }
@Override @Override
public void sessionDestroyed(HttpSessionEvent event) { public void sessionDestroyed(HttpSessionEvent event) {
System.out.println("==== Session is destroyed ===="); logger.info("==== Session is destroyed ====");
} }
} }

View File

@ -9,7 +9,7 @@ import org.springframework.data.jpa.repository.JpaRepository;
public interface PostRepository extends JpaRepository<Post, Long> { public interface PostRepository extends JpaRepository<Post, Long> {
public List<Post> findBySubmissionDateBeforeAndIsSent(Date date, boolean isSent); List<Post> findBySubmissionDateBeforeAndIsSent(Date date, boolean isSent);
public List<Post> findByUser(User user); List<Post> findByUser(User user);
} }

View File

@ -4,7 +4,8 @@ import org.baeldung.persistence.model.User;
import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.JpaRepository;
public interface UserRepository extends JpaRepository<User, Long> { public interface UserRepository extends JpaRepository<User, Long> {
public User findByUsername(String username);
public User findByAccessToken(String token); User findByUsername(String username);
User findByAccessToken(String token);
} }

View File

@ -100,4 +100,9 @@ public class Post {
this.submissionResponse = submissionResponse; this.submissionResponse = submissionResponse;
} }
@Override
public String toString() {
return "Post [title=" + title + ", subreddit=" + subreddit + ", url=" + url + ", submissionDate=" + submissionDate + ", user=" + user + "]";
}
} }

View File

@ -83,22 +83,30 @@ public class User {
public int hashCode() { public int hashCode() {
final int prime = 31; final int prime = 31;
int result = 1; int result = 1;
result = prime * result + ((username == null) ? 0 : username.hashCode()); result = (prime * result) + ((username == null) ? 0 : username.hashCode());
return result; return result;
} }
@Override @Override
public boolean equals(final Object obj) { public boolean equals(final Object obj) {
if (this == obj) if (this == obj) {
return true; return true;
if (obj == null) }
if (obj == null) {
return false; return false;
if (getClass() != obj.getClass()) }
if (getClass() != obj.getClass()) {
return false; return false;
}
final User user = (User) obj; final User user = (User) obj;
if (!username.equals(user.username)) if (!username.equals(user.username)) {
return false; return false;
}
return true; return true;
} }
@Override
public String toString() {
return "User [username=" + username + "]";
}
} }

View File

@ -0,0 +1,19 @@
package org.baeldung.reddit.util;
/**
* Specified by Reddit API at http://www.reddit.com/dev/api#POST_api_submit
*/
public final class RedditApiConstants {
public static final String API_TYPE = "api_type";
public static final String URL = "url";
public static final String SR = "sr";
public static final String TITLE = "title";
public static final String THEN = "then";
public static final String SENDREPLIES = "sendreplies";
public static final String RESUBMIT = "resubmit";
public static final String KIND = "kind";
private RedditApiConstants() {
throw new AssertionError();
}
}

View File

@ -11,6 +11,7 @@ import org.baeldung.persistence.dao.PostRepository;
import org.baeldung.persistence.dao.UserRepository; import org.baeldung.persistence.dao.UserRepository;
import org.baeldung.persistence.model.Post; import org.baeldung.persistence.model.Post;
import org.baeldung.persistence.model.User; import org.baeldung.persistence.model.User;
import org.baeldung.reddit.util.RedditApiConstants;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -43,8 +44,8 @@ public class RedditController {
@RequestMapping("/info") @RequestMapping("/info")
public final String getInfo(Model model) { public final String getInfo(Model model) {
JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class); final JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class);
String name = node.get("name").asText(); final String name = node.get("name").asText();
addUser(name, redditRestTemplate.getAccessToken()); addUser(name, redditRestTemplate.getAccessToken());
model.addAttribute("info", name); model.addAttribute("info", name);
return "reddit"; return "reddit";
@ -52,30 +53,30 @@ public class RedditController {
@RequestMapping("/submit") @RequestMapping("/submit")
public final String submit(Model model, @RequestParam Map<String, String> formParams) { public final String submit(Model model, @RequestParam Map<String, String> formParams) {
MultiValueMap<String, String> param = new LinkedMultiValueMap<String, String>(); final MultiValueMap<String, String> param = new LinkedMultiValueMap<String, String>();
param.add("api_type", "json"); param.add(RedditApiConstants.API_TYPE, "json");
param.add("kind", "link"); param.add(RedditApiConstants.KIND, "link");
param.add("resubmit", "true"); param.add(RedditApiConstants.RESUBMIT, "true");
param.add("sendreplies", "false"); param.add(RedditApiConstants.SENDREPLIES, "false");
param.add("then", "comments"); param.add(RedditApiConstants.THEN, "comments");
for (Map.Entry<String, String> entry : formParams.entrySet()) { for (final Map.Entry<String, String> entry : formParams.entrySet()) {
param.add(entry.getKey(), entry.getValue()); param.add(entry.getKey(), entry.getValue());
} }
logger.info("User submitting Link with these parameters: " + formParams.entrySet()); logger.info("User submitting Link with these parameters: " + formParams.entrySet());
JsonNode node = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, JsonNode.class); final JsonNode node = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, JsonNode.class);
logger.info("Full Reddit Response: " + node.toString()); logger.info("Full Reddit Response: " + node.toString());
String responseMsg = parseResponse(node); final String responseMsg = parseResponse(node);
model.addAttribute("msg", responseMsg); model.addAttribute("msg", responseMsg);
return "submissionResponse"; return "submissionResponse";
} }
@RequestMapping("/post") @RequestMapping("/post")
public final String showSubmissionForm(Model model) { public final String showSubmissionForm(Model model) {
String needsCaptchaResult = needsCaptcha(); final String needsCaptchaResult = needsCaptcha();
if (needsCaptchaResult.equalsIgnoreCase("true")) { if (needsCaptchaResult.equalsIgnoreCase("true")) {
String iden = getNewCaptcha(); final String iden = getNewCaptcha();
model.addAttribute("iden", iden); model.addAttribute("iden", iden);
} }
return "submissionForm"; return "submissionForm";
@ -83,7 +84,7 @@ public class RedditController {
@RequestMapping("/postSchedule") @RequestMapping("/postSchedule")
public final String showSchedulePostForm(Model model) { public final String showSchedulePostForm(Model model) {
String needsCaptchaResult = needsCaptcha(); final String needsCaptchaResult = needsCaptcha();
if (needsCaptchaResult.equalsIgnoreCase("true")) { if (needsCaptchaResult.equalsIgnoreCase("true")) {
model.addAttribute("msg", "Sorry, You do not have enought karma"); model.addAttribute("msg", "Sorry, You do not have enought karma");
return "submissionResponse"; return "submissionResponse";
@ -94,8 +95,8 @@ public class RedditController {
@RequestMapping("/schedule") @RequestMapping("/schedule")
public final String schedule(Model model, @RequestParam Map<String, String> formParams) throws ParseException { public final String schedule(Model model, @RequestParam Map<String, String> formParams) throws ParseException {
logger.info("User scheduling Post with these parameters: " + formParams.entrySet()); logger.info("User scheduling Post with these parameters: " + formParams.entrySet());
User user = userReopsitory.findByAccessToken(redditRestTemplate.getAccessToken().getValue()); final User user = userReopsitory.findByAccessToken(redditRestTemplate.getAccessToken().getValue());
Post post = new Post(); final Post post = new Post();
post.setUser(user); post.setUser(user);
post.setSent(false); post.setSent(false);
post.setTitle(formParams.get("title")); post.setTitle(formParams.get("title"));
@ -108,15 +109,15 @@ public class RedditController {
return "submissionResponse"; return "submissionResponse";
} }
postReopsitory.save(post); postReopsitory.save(post);
List<Post> posts = postReopsitory.findByUser(user); final List<Post> posts = postReopsitory.findByUser(user);
model.addAttribute("posts", posts); model.addAttribute("posts", posts);
return "postListView"; return "postListView";
} }
@RequestMapping("/posts") @RequestMapping("/posts")
public final String getScheduledPosts(Model model) { public final String getScheduledPosts(Model model) {
User user = userReopsitory.findByAccessToken(redditRestTemplate.getAccessToken().getValue()); final User user = userReopsitory.findByAccessToken(redditRestTemplate.getAccessToken().getValue());
List<Post> posts = postReopsitory.findByUser(user); final List<Post> posts = postReopsitory.findByUser(user);
model.addAttribute("posts", posts); model.addAttribute("posts", posts);
return "postListView"; return "postListView";
} }
@ -124,32 +125,33 @@ public class RedditController {
// === private // === private
private final String needsCaptcha() { private final String needsCaptcha() {
String result = redditRestTemplate.getForObject("https://oauth.reddit.com/api/needs_captcha.json", String.class); final String result = redditRestTemplate.getForObject("https://oauth.reddit.com/api/needs_captcha.json", String.class);
return result; return result;
} }
private final String getNewCaptcha() { private final String getNewCaptcha() {
Map<String, String> param = new HashMap<String, String>(); final Map<String, String> param = new HashMap<String, String>();
param.put("api_type", "json"); param.put("api_type", "json");
String result = redditRestTemplate.postForObject("https://oauth.reddit.com/api/new_captcha", param, String.class, param); final String result = redditRestTemplate.postForObject("https://oauth.reddit.com/api/new_captcha", param, String.class, param);
String[] split = result.split("\""); final String[] split = result.split("\"");
return split[split.length - 2]; return split[split.length - 2];
} }
private final String parseResponse(JsonNode node) { private final String parseResponse(JsonNode node) {
String result = ""; String result = "";
JsonNode errorNode = node.get("json").get("errors").get(0); final JsonNode errorNode = node.get("json").get("errors").get(0);
if (errorNode != null) { if (errorNode != null) {
for (JsonNode child : errorNode) { for (final JsonNode child : errorNode) {
result = result + child.toString().replaceAll("\"|null", "") + "<br>"; result = result + child.toString().replaceAll("\"|null", "") + "<br>";
} }
return result; return result;
} else { } else {
if (node.get("json").get("data") != null && node.get("json").get("data").get("url") != null) if ((node.get("json").get("data") != null) && (node.get("json").get("data").get("url") != null)) {
return "Post submitted successfully <a href=\"" + node.get("json").get("data").get("url").asText() + "\"> check it out </a>"; return "Post submitted successfully <a href=\"" + node.get("json").get("data").get("url").asText() + "\"> check it out </a>";
else } else {
return "Error Occurred"; return "Error Occurred";
}
} }
} }

View File

@ -26,7 +26,7 @@ public class RestExceptionHandler extends ResponseEntityExceptionHandler impleme
// 500 // 500
@ExceptionHandler({ UserApprovalRequiredException.class, UserRedirectRequiredException.class }) @ExceptionHandler({ UserApprovalRequiredException.class, UserRedirectRequiredException.class })
public ResponseEntity<Object> handleRedirect(final RuntimeException ex, final WebRequest request) { public ResponseEntity<Object> handleRedirect(final RuntimeException ex, final WebRequest request) {
logger.error("500 Status Code", ex); logger.info(ex.getLocalizedMessage());
throw ex; throw ex;
} }
@ -34,7 +34,7 @@ public class RestExceptionHandler extends ResponseEntityExceptionHandler impleme
public ResponseEntity<Object> handleInternal(final RuntimeException ex, final WebRequest request) { public ResponseEntity<Object> handleInternal(final RuntimeException ex, final WebRequest request) {
logger.info(request.getHeader("x-ratelimit-remaining")); logger.info(request.getHeader("x-ratelimit-remaining"));
logger.error("500 Status Code", ex); logger.error("500 Status Code", ex);
String response = "Error Occurred : " + ex.getMessage(); final String response = "Error Occurred : " + ex.getMessage();
return handleExceptionInternal(ex, response, new HttpHeaders(), HttpStatus.INTERNAL_SERVER_ERROR, request); return handleExceptionInternal(ex, response, new HttpHeaders(), HttpStatus.INTERNAL_SERVER_ERROR, request);
} }
} }

View File

@ -7,6 +7,7 @@ import java.util.List;
import org.baeldung.persistence.dao.PostRepository; import org.baeldung.persistence.dao.PostRepository;
import org.baeldung.persistence.model.Post; import org.baeldung.persistence.model.Post;
import org.baeldung.persistence.model.User; import org.baeldung.persistence.model.User;
import org.baeldung.reddit.util.RedditApiConstants;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -34,50 +35,49 @@ public class ScheduledTasks {
@Scheduled(fixedRate = 1 * 60 * 1000) @Scheduled(fixedRate = 1 * 60 * 1000)
public void reportCurrentTime() { public void reportCurrentTime() {
List<Post> posts = postReopsitory.findBySubmissionDateBeforeAndIsSent(new Date(), false); final List<Post> posts = postReopsitory.findBySubmissionDateBeforeAndIsSent(new Date(), false);
logger.info(posts.size() + " Posts in the queue."); logger.info(posts.size() + " Posts in the queue.");
for (Post post : posts) { for (final Post post : posts) {
submitPost(post); submitPost(post);
} }
} }
private void submitPost(Post post) { private void submitPost(Post post) {
try { try {
User user = post.getUser(); final User user = post.getUser();
DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(user.getAccessToken()); final DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(user.getAccessToken());
token.setRefreshToken(new DefaultOAuth2RefreshToken((user.getRefreshToken()))); token.setRefreshToken(new DefaultOAuth2RefreshToken((user.getRefreshToken())));
token.setExpiration(user.getTokenExpiration()); token.setExpiration(user.getTokenExpiration());
redditRestTemplate.getOAuth2ClientContext().setAccessToken(token); redditRestTemplate.getOAuth2ClientContext().setAccessToken(token);
// //
UsernamePasswordAuthenticationToken userAuthToken = new UsernamePasswordAuthenticationToken(user.getUsername(), token.getValue(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER"))); final UsernamePasswordAuthenticationToken userAuthToken = new UsernamePasswordAuthenticationToken(user.getUsername(), token.getValue(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER")));
SecurityContextHolder.getContext().setAuthentication(userAuthToken); SecurityContextHolder.getContext().setAuthentication(userAuthToken);
// //
final MultiValueMap<String, String> param = new LinkedMultiValueMap<String, String>();
MultiValueMap<String, String> param = new LinkedMultiValueMap<String, String>(); param.add(RedditApiConstants.TITLE, post.getTitle());
param.add("api_type", "json"); param.add(RedditApiConstants.SR, post.getSubreddit());
param.add("kind", "link"); param.add(RedditApiConstants.URL, post.getUrl());
param.add("resubmit", "true"); param.add(RedditApiConstants.API_TYPE, "json");
param.add("sendreplies", "false"); param.add(RedditApiConstants.KIND, "link");
param.add("then", "comments"); param.add(RedditApiConstants.RESUBMIT, "true");
param.add("title", post.getTitle()); param.add(RedditApiConstants.SENDREPLIES, "false");
param.add("sr", post.getSubreddit()); param.add(RedditApiConstants.THEN, "comments");
param.add("url", post.getUrl());
logger.info("Submit link with these parameters: " + param.entrySet()); logger.info("Submit link with these parameters: " + param.entrySet());
JsonNode node = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, JsonNode.class); final JsonNode node = redditRestTemplate.postForObject("https://oauth.reddit.com/api/submit", param, JsonNode.class);
JsonNode errorNode = node.get("json").get("errors").get(0); final JsonNode errorNode = node.get("json").get("errors").get(0);
if (errorNode == null) { if (errorNode == null) {
post.setSent(true); post.setSent(true);
post.setSubmissionResponse("Successfully sent"); post.setSubmissionResponse("Successfully sent");
postReopsitory.save(post); postReopsitory.save(post);
logger.info("Successfully sent"); logger.info("Successfully sent " + post.toString());
} else { } else {
post.setSubmissionResponse(errorNode.toString()); post.setSubmissionResponse(errorNode.toString());
postReopsitory.save(post); postReopsitory.save(post);
logger.info("Error occurred: " + errorNode.toString()); logger.info("Error occurred: " + errorNode.toString() + "while submitting post " + post.toString());
} }
} catch (Exception e) { } catch (final Exception e) {
logger.error("Error occurred", e); logger.error("Error occurred while submitting post " + post.toString(), e);
} }
} }

View File

@ -1,6 +1,6 @@
################### DataSource Configuration ########################## ################### DataSource Configuration ##########################
jdbc.driverClassName=com.mysql.jdbc.Driver jdbc.driverClassName=com.mysql.jdbc.Driver
jdbc.url=jdbc:mysql://localhost:3306/oauth_reddit?createDatabaseIfNotExist=true jdbc.url=jdbc:mysql://localhost:3606/oauth_reddit?createDatabaseIfNotExist=true
jdbc.user=tutorialuser jdbc.user=tutorialuser
jdbc.pass=tutorialmy5ql jdbc.pass=tutorialmy5ql
init-db=false init-db=false

View File

@ -0,0 +1,100 @@
package org.baeldung.persistence;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.collection.IsIn.isIn;
import static org.hamcrest.core.IsNot.not;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.List;
import org.baeldung.config.PersistenceJPAConfig;
import org.baeldung.persistence.dao.PostRepository;
import org.baeldung.persistence.dao.UserRepository;
import org.baeldung.persistence.model.Post;
import org.baeldung.persistence.model.User;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.transaction.TransactionConfiguration;
import org.springframework.transaction.annotation.Transactional;
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = { PersistenceJPAConfig.class })
@Transactional
@TransactionConfiguration
public class PersistenceJPATest {
@Autowired
private PostRepository postRepository;
@Autowired
private UserRepository userRepository;
private Post alreadySentPost, notSentYetOld, notSentYet;
private User userJohn, userTom;
private static final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm");
@Before
public void init() throws ParseException {
userJohn = new User();
userJohn.setUsername("John");
userRepository.save(userJohn);
userTom = new User();
userTom.setUsername("Tom");
userRepository.save(userTom);
alreadySentPost = new Post();
alreadySentPost.setTitle("First post title");
alreadySentPost.setSent(true);
alreadySentPost.setSubmissionDate(dateFormat.parse("2015-03-03 10:30"));
alreadySentPost.setUser(userJohn);
postRepository.save(alreadySentPost);
notSentYetOld = new Post();
notSentYetOld.setTitle("Second post title");
notSentYetOld.setSent(false);
notSentYetOld.setSubmissionDate(dateFormat.parse("2015-03-03 11:00"));
notSentYetOld.setUser(userTom);
postRepository.save(notSentYetOld);
notSentYet = new Post();
notSentYet.setTitle("Second post title");
notSentYet.setSent(false);
notSentYet.setSubmissionDate(dateFormat.parse("2015-03-03 11:30"));
notSentYet.setUser(userJohn);
postRepository.save(notSentYet);
}
@Test
public void whenGettingListOfSentPosts_thenCorrect() throws ParseException {
final List<Post> results = postRepository.findBySubmissionDateBeforeAndIsSent(dateFormat.parse("2015-03-03 11:50"), true);
assertThat(alreadySentPost, isIn(results));
assertThat(notSentYet, not(isIn(results)));
assertThat(notSentYetOld, not(isIn(results)));
}
@Test
public void whenGettingListOfOldPosts_thenCorrect() throws ParseException {
final List<Post> results = postRepository.findBySubmissionDateBeforeAndIsSent(dateFormat.parse("2015-03-03 11:01"), false);
assertThat(notSentYetOld, isIn(results));
assertThat(notSentYet, not(isIn(results)));
assertThat(alreadySentPost, not(isIn(results)));
}
@Test
public void whenGettingListOfSpecificuser_thenCorrect() throws ParseException {
final List<Post> results = postRepository.findByUser(userTom);
assertThat(notSentYetOld, isIn(results));
assertThat(notSentYet, not(isIn(results)));
assertThat(alreadySentPost, not(isIn(results)));
}
}