spring oauth auto login
This commit is contained in:
parent
3ffe8425f3
commit
ccad39e3b3
|
@ -31,8 +31,12 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
|
|||
.authorizeRequests()
|
||||
.antMatchers("/home.html","/post","/postSchedule","/posts").hasRole("USER")
|
||||
.and()
|
||||
.httpBasic().authenticationEntryPoint(oauth2AuthenticationEntryPoint());
|
||||
|
||||
.httpBasic().authenticationEntryPoint(oauth2AuthenticationEntryPoint())
|
||||
.and()
|
||||
.logout()
|
||||
.deleteCookies("JSESSIONID","CustomRememberMe")
|
||||
.logoutUrl("/logout")
|
||||
.logoutSuccessUrl("/");
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ public class SessionListener implements HttpSessionListener {
|
|||
@Override
|
||||
public void sessionCreated(HttpSessionEvent event) {
|
||||
logger.info("==== Session is created ====");
|
||||
event.getSession().setMaxInactiveInterval(30 * 60);
|
||||
event.getSession().setMaxInactiveInterval(1 * 60);
|
||||
event.getSession().setAttribute("PREDICTION_FEATURE", MyFeatures.PREDICTION_FEATURE);
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.baeldung.persistence.service.RedditTokenService;
|
||||
import org.baeldung.reddit.classifier.RedditClassifier;
|
||||
import org.baeldung.reddit.util.UserAgentInterceptor;
|
||||
import org.baeldung.web.schedule.ScheduledTasks;
|
||||
|
@ -16,8 +17,6 @@ import org.springframework.context.annotation.Configuration;
|
|||
import org.springframework.context.annotation.PropertySource;
|
||||
import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.http.client.ClientHttpRequestInterceptor;
|
||||
import org.springframework.scheduling.annotation.EnableAsync;
|
||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||
|
@ -86,9 +85,9 @@ public class WebConfig extends WebMvcConfigurerAdapter {
|
|||
|
||||
@Bean
|
||||
public RedditClassifier redditClassifier() throws IOException {
|
||||
final Resource file = new ClassPathResource("data.csv");
|
||||
// final Resource file = new ClassPathResource("data.csv");
|
||||
final RedditClassifier redditClassifier = new RedditClassifier();
|
||||
redditClassifier.trainClassifier(file.getFile().getAbsolutePath());
|
||||
// redditClassifier.trainClassifier(file.getFile().getAbsolutePath());
|
||||
return redditClassifier;
|
||||
}
|
||||
|
||||
|
@ -131,13 +130,14 @@ public class WebConfig extends WebMvcConfigurerAdapter {
|
|||
}
|
||||
|
||||
@Bean
|
||||
public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext) {
|
||||
public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext, RedditTokenService redditTokenService) {
|
||||
final OAuth2RestTemplate template = new OAuth2RestTemplate(reddit(), clientContext);
|
||||
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
|
||||
list.add(new UserAgentInterceptor());
|
||||
template.setInterceptors(list);
|
||||
final AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(),
|
||||
final AccessTokenProviderChain accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(),
|
||||
new ClientCredentialsAccessTokenProvider()));
|
||||
accessTokenProvider.setClientTokenServices(redditTokenService);
|
||||
template.setAccessTokenProvider(accessTokenProvider);
|
||||
return template;
|
||||
}
|
||||
|
|
|
@ -8,4 +8,6 @@ public interface UserRepository extends JpaRepository<User, Long> {
|
|||
User findByUsername(String username);
|
||||
|
||||
User findByAccessToken(String token);
|
||||
|
||||
User findByRememberMeToken(String token);
|
||||
}
|
|
@ -26,6 +26,8 @@ public class User {
|
|||
|
||||
private boolean needCaptcha;
|
||||
|
||||
private String rememberMeToken;
|
||||
|
||||
public User() {
|
||||
super();
|
||||
}
|
||||
|
@ -80,6 +82,14 @@ public class User {
|
|||
|
||||
//
|
||||
|
||||
public String getRememberMeToken() {
|
||||
return rememberMeToken;
|
||||
}
|
||||
|
||||
public void setRememberMeToken(String rememberMeToken) {
|
||||
this.rememberMeToken = rememberMeToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
package org.baeldung.persistence.service;
|
||||
|
||||
import org.baeldung.persistence.dao.UserRepository;
|
||||
import org.baeldung.persistence.model.User;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
|
||||
import org.springframework.security.oauth2.client.token.ClientTokenServices;
|
||||
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken;
|
||||
import org.springframework.security.oauth2.common.OAuth2AccessToken;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class RedditTokenService implements ClientTokenServices {
|
||||
|
||||
@Autowired
|
||||
private UserRepository userReopsitory;
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
public RedditTokenService() {
|
||||
super();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OAuth2AccessToken getAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
|
||||
logger.info("reddit ==== getAccessToken");
|
||||
final User user = (User) authentication.getPrincipal();
|
||||
final DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(user.getAccessToken());
|
||||
token.setRefreshToken(new DefaultOAuth2RefreshToken((user.getRefreshToken())));
|
||||
token.setExpiration(user.getTokenExpiration());
|
||||
return token;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication, OAuth2AccessToken accessToken) {
|
||||
logger.info("reddit ==== saveAccessToken");
|
||||
final User user = (User) authentication.getPrincipal();
|
||||
user.setAccessToken(accessToken.getValue());
|
||||
if (accessToken.getRefreshToken() != null) {
|
||||
user.setRefreshToken(accessToken.getRefreshToken().getValue());
|
||||
}
|
||||
user.setTokenExpiration(accessToken.getExpiration());
|
||||
userReopsitory.save(user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
|
||||
logger.info("reddit ==== removeAccessToken");
|
||||
}
|
||||
|
||||
}
|
|
@ -8,6 +8,11 @@ import java.util.HashMap;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import javax.servlet.http.Cookie;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.apache.commons.lang.RandomStringUtils;
|
||||
import org.baeldung.persistence.dao.PostRepository;
|
||||
import org.baeldung.persistence.dao.UserRepository;
|
||||
import org.baeldung.persistence.model.Post;
|
||||
|
@ -27,6 +32,7 @@ import org.springframework.stereotype.Controller;
|
|||
import org.springframework.ui.Model;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.bind.annotation.CookieValue;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMethod;
|
||||
|
@ -43,6 +49,7 @@ public class RedditController {
|
|||
|
||||
private static final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm");
|
||||
private final SimpleDateFormat dfHour = new SimpleDateFormat("HH");
|
||||
public static final String REMEMBER_ME_COOKIE = "CustomRememberMe";
|
||||
|
||||
@Autowired
|
||||
private OAuth2RestTemplate redditRestTemplate;
|
||||
|
@ -57,9 +64,11 @@ public class RedditController {
|
|||
private RedditClassifier redditClassifier;
|
||||
|
||||
@RequestMapping("/login")
|
||||
public final String redditLogin() {
|
||||
public final String redditLogin(@CookieValue(value = REMEMBER_ME_COOKIE, required = false) String rememberMe, HttpServletRequest request, HttpServletResponse response) {
|
||||
if (!canAutoLogin(rememberMe)) {
|
||||
final JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class);
|
||||
loadAuthentication(node.get("name").asText(), redditRestTemplate.getAccessToken());
|
||||
loadAuthentication(node.get("name").asText(), redditRestTemplate.getAccessToken(), response);
|
||||
}
|
||||
return "redirect:home.html";
|
||||
}
|
||||
|
||||
|
@ -221,7 +230,7 @@ public class RedditController {
|
|||
}
|
||||
}
|
||||
|
||||
private final void loadAuthentication(final String name, final OAuth2AccessToken token) {
|
||||
private void loadAuthentication(final String name, final OAuth2AccessToken token, HttpServletResponse response) {
|
||||
User user = userReopsitory.findByUsername(name);
|
||||
if (user == null) {
|
||||
user = new User();
|
||||
|
@ -239,8 +248,35 @@ public class RedditController {
|
|||
user.setTokenExpiration(token.getExpiration());
|
||||
userReopsitory.save(user);
|
||||
|
||||
generateRememberMeToken(user, response);
|
||||
|
||||
final UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken(user, token.getValue(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER")));
|
||||
SecurityContextHolder.getContext().setAuthentication(auth);
|
||||
}
|
||||
|
||||
private void generateRememberMeToken(User user, HttpServletResponse response) {
|
||||
String rememberMe = RandomStringUtils.randomAlphanumeric(30);
|
||||
while (userReopsitory.findByRememberMeToken(rememberMe) != null) {
|
||||
rememberMe = RandomStringUtils.randomAlphanumeric(30);
|
||||
}
|
||||
user.setRememberMeToken(rememberMe);
|
||||
userReopsitory.save(user);
|
||||
final Cookie c = new Cookie(REMEMBER_ME_COOKIE, rememberMe);
|
||||
c.setMaxAge(1209600);
|
||||
response.addCookie(c);
|
||||
}
|
||||
|
||||
private boolean canAutoLogin(String rememberMeToken) {
|
||||
if (rememberMeToken != null) {
|
||||
final User user = userReopsitory.findByRememberMeToken(rememberMeToken);
|
||||
if (user != null) {
|
||||
logger.info("Auto Login successfully");
|
||||
final UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken(user, user.getAccessToken(), Arrays.asList(new SimpleGrantedAuthority("ROLE_USER")));
|
||||
SecurityContextHolder.getContext().setAuthentication(auth);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
<p class="navbar-text navbar-right">Logged in as
|
||||
<b><sec:authentication property="principal.username" /></b>
|
||||
<a href="logout">Logout</a>
|
||||
</p>
|
||||
|
||||
<div class="collapse navbar-collapse" id="bs-example-navbar-collapse-1">
|
||||
|
|
Loading…
Reference in New Issue