diff --git a/spring-security-oauth/src/main/java/org/baeldung/config/SecurityConfig.java b/spring-security-oauth/src/main/java/org/baeldung/config/SecurityConfig.java index d120a2d775..cdfccb99e6 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/config/SecurityConfig.java +++ b/spring-security-oauth/src/main/java/org/baeldung/config/SecurityConfig.java @@ -31,8 +31,12 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { .authorizeRequests() .antMatchers("/home.html","/post","/postSchedule","/posts").hasRole("USER") .and() - .httpBasic().authenticationEntryPoint(oauth2AuthenticationEntryPoint()); - + .httpBasic().authenticationEntryPoint(oauth2AuthenticationEntryPoint()) + .and() + .logout() + .deleteCookies("JSESSIONID","CustomRememberMe") + .logoutUrl("/logout") + .logoutSuccessUrl("/"); // @formatter:on } diff --git a/spring-security-oauth/src/main/java/org/baeldung/config/SessionListener.java b/spring-security-oauth/src/main/java/org/baeldung/config/SessionListener.java index 96cf6058e4..53b68cbc1c 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/config/SessionListener.java +++ b/spring-security-oauth/src/main/java/org/baeldung/config/SessionListener.java @@ -13,7 +13,7 @@ public class SessionListener implements HttpSessionListener { @Override public void sessionCreated(HttpSessionEvent event) { logger.info("==== Session is created ===="); - event.getSession().setMaxInactiveInterval(30 * 60); + event.getSession().setMaxInactiveInterval(1 * 60); event.getSession().setAttribute("PREDICTION_FEATURE", MyFeatures.PREDICTION_FEATURE); } diff --git a/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java b/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java index ddddc781c5..7b03b1cc8a 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java +++ b/spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java @@ -5,6 +5,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.baeldung.persistence.service.RedditTokenService; import org.baeldung.reddit.classifier.RedditClassifier; import org.baeldung.reddit.util.UserAgentInterceptor; import org.baeldung.web.schedule.ScheduledTasks; @@ -16,8 +17,6 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.PropertySource; import org.springframework.context.support.PropertySourcesPlaceholderConfigurer; import org.springframework.core.env.Environment; -import org.springframework.core.io.ClassPathResource; -import org.springframework.core.io.Resource; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.scheduling.annotation.EnableScheduling; @@ -86,9 +85,9 @@ public class WebConfig extends WebMvcConfigurerAdapter { @Bean public RedditClassifier redditClassifier() throws IOException { - final Resource file = new ClassPathResource("data.csv"); + // final Resource file = new ClassPathResource("data.csv"); final RedditClassifier redditClassifier = new RedditClassifier(); - redditClassifier.trainClassifier(file.getFile().getAbsolutePath()); + // redditClassifier.trainClassifier(file.getFile().getAbsolutePath()); return redditClassifier; } @@ -131,15 +130,16 @@ public class WebConfig extends WebMvcConfigurerAdapter { } @Bean - public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext) { + public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext, RedditTokenService redditTokenService) { final OAuth2RestTemplate template = new OAuth2RestTemplate(reddit(), clientContext); final List list = new ArrayList(); list.add(new UserAgentInterceptor()); template.setInterceptors(list); - final AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays. asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(), + final AccessTokenProviderChain accessTokenProvider = new AccessTokenProviderChain(Arrays. asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(), new ClientCredentialsAccessTokenProvider())); + accessTokenProvider.setClientTokenServices(redditTokenService); template.setAccessTokenProvider(accessTokenProvider); return template; } } -} \ No newline at end of file +} diff --git a/spring-security-oauth/src/main/java/org/baeldung/persistence/dao/UserRepository.java b/spring-security-oauth/src/main/java/org/baeldung/persistence/dao/UserRepository.java index ec2221c73d..4b045aa37c 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/persistence/dao/UserRepository.java +++ b/spring-security-oauth/src/main/java/org/baeldung/persistence/dao/UserRepository.java @@ -8,4 +8,6 @@ public interface UserRepository extends JpaRepository { User findByUsername(String username); User findByAccessToken(String token); + + User findByRememberMeToken(String token); } \ No newline at end of file diff --git a/spring-security-oauth/src/main/java/org/baeldung/persistence/model/User.java b/spring-security-oauth/src/main/java/org/baeldung/persistence/model/User.java index e3b5553922..d47fe067d2 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/persistence/model/User.java +++ b/spring-security-oauth/src/main/java/org/baeldung/persistence/model/User.java @@ -26,6 +26,8 @@ public class User { private boolean needCaptcha; + private String rememberMeToken; + public User() { super(); } @@ -80,6 +82,14 @@ public class User { // + public String getRememberMeToken() { + return rememberMeToken; + } + + public void setRememberMeToken(String rememberMeToken) { + this.rememberMeToken = rememberMeToken; + } + @Override public int hashCode() { final int prime = 31; diff --git a/spring-security-oauth/src/main/java/org/baeldung/persistence/service/RedditTokenService.java b/spring-security-oauth/src/main/java/org/baeldung/persistence/service/RedditTokenService.java new file mode 100644 index 0000000000..5a20fff1de --- /dev/null +++ b/spring-security-oauth/src/main/java/org/baeldung/persistence/service/RedditTokenService.java @@ -0,0 +1,55 @@ +package org.baeldung.persistence.service; + +import org.baeldung.persistence.dao.UserRepository; +import org.baeldung.persistence.model.User; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails; +import org.springframework.security.oauth2.client.token.ClientTokenServices; +import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken; +import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken; +import org.springframework.security.oauth2.common.OAuth2AccessToken; +import org.springframework.stereotype.Component; + +@Component +public class RedditTokenService implements ClientTokenServices { + + @Autowired + private UserRepository userReopsitory; + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + public RedditTokenService() { + super(); + } + + @Override + public OAuth2AccessToken getAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) { + logger.info("reddit ==== getAccessToken"); + final User user = (User) authentication.getPrincipal(); + final DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(user.getAccessToken()); + token.setRefreshToken(new DefaultOAuth2RefreshToken((user.getRefreshToken()))); + token.setExpiration(user.getTokenExpiration()); + return token; + } + + @Override + public void saveAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication, OAuth2AccessToken accessToken) { + logger.info("reddit ==== saveAccessToken"); + final User user = (User) authentication.getPrincipal(); + user.setAccessToken(accessToken.getValue()); + if (accessToken.getRefreshToken() != null) { + user.setRefreshToken(accessToken.getRefreshToken().getValue()); + } + user.setTokenExpiration(accessToken.getExpiration()); + userReopsitory.save(user); + } + + @Override + public void removeAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) { + logger.info("reddit ==== removeAccessToken"); + } + +} diff --git a/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java b/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java index d05bdf1eab..2656debbaa 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java +++ b/spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java @@ -8,6 +8,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.lang.RandomStringUtils; import org.baeldung.persistence.dao.PostRepository; import org.baeldung.persistence.dao.UserRepository; import org.baeldung.persistence.model.Post; @@ -27,6 +32,7 @@ import org.springframework.stereotype.Controller; import org.springframework.ui.Model; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.CookieValue; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; @@ -43,6 +49,7 @@ public class RedditController { private static final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm"); private final SimpleDateFormat dfHour = new SimpleDateFormat("HH"); + public static final String REMEMBER_ME_COOKIE = "CustomRememberMe"; @Autowired private OAuth2RestTemplate redditRestTemplate; @@ -57,9 +64,11 @@ public class RedditController { private RedditClassifier redditClassifier; @RequestMapping("/login") - public final String redditLogin() { - final JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class); - loadAuthentication(node.get("name").asText(), redditRestTemplate.getAccessToken()); + public final String redditLogin(@CookieValue(value = REMEMBER_ME_COOKIE, required = false) String rememberMe, HttpServletRequest request, HttpServletResponse response) { + if (!canAutoLogin(rememberMe)) { + final JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class); + loadAuthentication(node.get("name").asText(), redditRestTemplate.getAccessToken(), response); + } return "redirect:home.html"; } @@ -221,7 +230,7 @@ public class RedditController { } } - private final void loadAuthentication(final String name, final OAuth2AccessToken token) { + private void loadAuthentication(final String name, final OAuth2AccessToken token, HttpServletResponse response) { User user = userReopsitory.findByUsername(name); if (user == null) { user = new User(); @@ -239,8 +248,35 @@ public class RedditController { user.setTokenExpiration(token.getExpiration()); userReopsitory.save(user); + generateRememberMeToken(user, response); + final UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken(user, token.getValue(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER"))); SecurityContextHolder.getContext().setAuthentication(auth); } + private void generateRememberMeToken(User user, HttpServletResponse response) { + String rememberMe = RandomStringUtils.randomAlphanumeric(30); + while (userReopsitory.findByRememberMeToken(rememberMe) != null) { + rememberMe = RandomStringUtils.randomAlphanumeric(30); + } + user.setRememberMeToken(rememberMe); + userReopsitory.save(user); + final Cookie c = new Cookie(REMEMBER_ME_COOKIE, rememberMe); + c.setMaxAge(1209600); + response.addCookie(c); + } + + private boolean canAutoLogin(String rememberMeToken) { + if (rememberMeToken != null) { + final User user = userReopsitory.findByRememberMeToken(rememberMeToken); + if (user != null) { + logger.info("Auto Login successfully"); + final UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken(user, user.getAccessToken(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER"))); + SecurityContextHolder.getContext().setAuthentication(auth); + return true; + } + } + return false; + } + } diff --git a/spring-security-oauth/src/main/webapp/WEB-INF/jsp/home.jsp b/spring-security-oauth/src/main/webapp/WEB-INF/jsp/home.jsp index 777508b860..c70f2df7ce 100755 --- a/spring-security-oauth/src/main/webapp/WEB-INF/jsp/home.jsp +++ b/spring-security-oauth/src/main/webapp/WEB-INF/jsp/home.jsp @@ -16,6 +16,7 @@