spring oauth auto login

This commit is contained in:
DOHA 2015-04-26 22:09:23 +02:00
parent 3ffe8425f3
commit ccad39e3b3
8 changed files with 122 additions and 14 deletions

View File

@ -31,8 +31,12 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
.authorizeRequests()
.antMatchers("/home.html","/post","/postSchedule","/posts").hasRole("USER")
.and()
.httpBasic().authenticationEntryPoint(oauth2AuthenticationEntryPoint());
.httpBasic().authenticationEntryPoint(oauth2AuthenticationEntryPoint())
.and()
.logout()
.deleteCookies("JSESSIONID","CustomRememberMe")
.logoutUrl("/logout")
.logoutSuccessUrl("/");
// @formatter:on
}

View File

@ -13,7 +13,7 @@ public class SessionListener implements HttpSessionListener {
@Override
public void sessionCreated(HttpSessionEvent event) {
logger.info("==== Session is created ====");
event.getSession().setMaxInactiveInterval(30 * 60);
event.getSession().setMaxInactiveInterval(1 * 60);
event.getSession().setAttribute("PREDICTION_FEATURE", MyFeatures.PREDICTION_FEATURE);
}

View File

@ -5,6 +5,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.baeldung.persistence.service.RedditTokenService;
import org.baeldung.reddit.classifier.RedditClassifier;
import org.baeldung.reddit.util.UserAgentInterceptor;
import org.baeldung.web.schedule.ScheduledTasks;
@ -16,8 +17,6 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.PropertySource;
import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@ -86,9 +85,9 @@ public class WebConfig extends WebMvcConfigurerAdapter {
@Bean
public RedditClassifier redditClassifier() throws IOException {
final Resource file = new ClassPathResource("data.csv");
// final Resource file = new ClassPathResource("data.csv");
final RedditClassifier redditClassifier = new RedditClassifier();
redditClassifier.trainClassifier(file.getFile().getAbsolutePath());
// redditClassifier.trainClassifier(file.getFile().getAbsolutePath());
return redditClassifier;
}
@ -131,13 +130,14 @@ public class WebConfig extends WebMvcConfigurerAdapter {
}
@Bean
public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext) {
public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext, RedditTokenService redditTokenService) {
final OAuth2RestTemplate template = new OAuth2RestTemplate(reddit(), clientContext);
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
list.add(new UserAgentInterceptor());
template.setInterceptors(list);
final AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(),
final AccessTokenProviderChain accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(),
new ClientCredentialsAccessTokenProvider()));
accessTokenProvider.setClientTokenServices(redditTokenService);
template.setAccessTokenProvider(accessTokenProvider);
return template;
}

View File

@ -8,4 +8,6 @@ public interface UserRepository extends JpaRepository<User, Long> {
User findByUsername(String username);
User findByAccessToken(String token);
User findByRememberMeToken(String token);
}

View File

@ -26,6 +26,8 @@ public class User {
private boolean needCaptcha;
private String rememberMeToken;
public User() {
super();
}
@ -80,6 +82,14 @@ public class User {
//
public String getRememberMeToken() {
return rememberMeToken;
}
public void setRememberMeToken(String rememberMeToken) {
this.rememberMeToken = rememberMeToken;
}
@Override
public int hashCode() {
final int prime = 31;

View File

@ -0,0 +1,55 @@
package org.baeldung.persistence.service;
import org.baeldung.persistence.dao.UserRepository;
import org.baeldung.persistence.model.User;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.token.ClientTokenServices;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.stereotype.Component;
@Component
public class RedditTokenService implements ClientTokenServices {
@Autowired
private UserRepository userReopsitory;
private final Logger logger = LoggerFactory.getLogger(getClass());
public RedditTokenService() {
super();
}
@Override
public OAuth2AccessToken getAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
logger.info("reddit ==== getAccessToken");
final User user = (User) authentication.getPrincipal();
final DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(user.getAccessToken());
token.setRefreshToken(new DefaultOAuth2RefreshToken((user.getRefreshToken())));
token.setExpiration(user.getTokenExpiration());
return token;
}
@Override
public void saveAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication, OAuth2AccessToken accessToken) {
logger.info("reddit ==== saveAccessToken");
final User user = (User) authentication.getPrincipal();
user.setAccessToken(accessToken.getValue());
if (accessToken.getRefreshToken() != null) {
user.setRefreshToken(accessToken.getRefreshToken().getValue());
}
user.setTokenExpiration(accessToken.getExpiration());
userReopsitory.save(user);
}
@Override
public void removeAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
logger.info("reddit ==== removeAccessToken");
}
}

View File

@ -8,6 +8,11 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.lang.RandomStringUtils;
import org.baeldung.persistence.dao.PostRepository;
import org.baeldung.persistence.dao.UserRepository;
import org.baeldung.persistence.model.Post;
@ -27,6 +32,7 @@ import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.annotation.CookieValue;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
@ -43,6 +49,7 @@ public class RedditController {
private static final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm");
private final SimpleDateFormat dfHour = new SimpleDateFormat("HH");
public static final String REMEMBER_ME_COOKIE = "CustomRememberMe";
@Autowired
private OAuth2RestTemplate redditRestTemplate;
@ -57,9 +64,11 @@ public class RedditController {
private RedditClassifier redditClassifier;
@RequestMapping("/login")
public final String redditLogin() {
public final String redditLogin(@CookieValue(value = REMEMBER_ME_COOKIE, required = false) String rememberMe, HttpServletRequest request, HttpServletResponse response) {
if (!canAutoLogin(rememberMe)) {
final JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class);
loadAuthentication(node.get("name").asText(), redditRestTemplate.getAccessToken());
loadAuthentication(node.get("name").asText(), redditRestTemplate.getAccessToken(), response);
}
return "redirect:home.html";
}
@ -221,7 +230,7 @@ public class RedditController {
}
}
private final void loadAuthentication(final String name, final OAuth2AccessToken token) {
private void loadAuthentication(final String name, final OAuth2AccessToken token, HttpServletResponse response) {
User user = userReopsitory.findByUsername(name);
if (user == null) {
user = new User();
@ -239,8 +248,35 @@ public class RedditController {
user.setTokenExpiration(token.getExpiration());
userReopsitory.save(user);
generateRememberMeToken(user, response);
final UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken(user, token.getValue(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER")));
SecurityContextHolder.getContext().setAuthentication(auth);
}
private void generateRememberMeToken(User user, HttpServletResponse response) {
String rememberMe = RandomStringUtils.randomAlphanumeric(30);
while (userReopsitory.findByRememberMeToken(rememberMe) != null) {
rememberMe = RandomStringUtils.randomAlphanumeric(30);
}
user.setRememberMeToken(rememberMe);
userReopsitory.save(user);
final Cookie c = new Cookie(REMEMBER_ME_COOKIE, rememberMe);
c.setMaxAge(1209600);
response.addCookie(c);
}
private boolean canAutoLogin(String rememberMeToken) {
if (rememberMeToken != null) {
final User user = userReopsitory.findByRememberMeToken(rememberMeToken);
if (user != null) {
logger.info("Auto Login successfully");
final UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken(user, user.getAccessToken(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER")));
SecurityContextHolder.getContext().setAuthentication(auth);
return true;
}
}
return false;
}
}

View File

@ -16,6 +16,7 @@
<p class="navbar-text navbar-right">Logged in as
<b><sec:authentication property="principal.username" /></b>&nbsp;&nbsp;&nbsp;
<a href="logout">Logout</a>&nbsp;&nbsp;&nbsp;
</p>
<div class="collapse navbar-collapse" id="bs-example-navbar-collapse-1">