Polish OneTimeTokenLoginConfigurer

Signed-off-by: DingHao <dh.hiekn@gmail.com>
This commit is contained in:
DingHao 2025-01-23 09:51:33 +08:00 committed by Josh Cummings
parent fc19bf8769
commit f7e0f7fa8a

View File

@ -18,7 +18,6 @@ package org.springframework.security.config.annotation.web.configurers.ott;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
@ -91,7 +90,7 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
@Override @Override
public void init(H http) { public void init(H http) {
AuthenticationProvider authenticationProvider = getAuthenticationProvider(http); AuthenticationProvider authenticationProvider = getAuthenticationProvider();
http.authenticationProvider(postProcess(authenticationProvider)); http.authenticationProvider(postProcess(authenticationProvider));
configureDefaultLoginPage(http); configureDefaultLoginPage(http);
} }
@ -138,17 +137,19 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
} }
private void configureOttGenerateFilter(H http) { private void configureOttGenerateFilter(H http) {
GenerateOneTimeTokenFilter generateFilter = new GenerateOneTimeTokenFilter(getOneTimeTokenService(http), GenerateOneTimeTokenFilter generateFilter = new GenerateOneTimeTokenFilter(getOneTimeTokenService(),
getOneTimeTokenGenerationSuccessHandler(http)); getOneTimeTokenGenerationSuccessHandler());
generateFilter.setRequestMatcher(antMatcher(HttpMethod.POST, this.tokenGeneratingUrl)); generateFilter.setRequestMatcher(antMatcher(HttpMethod.POST, this.tokenGeneratingUrl));
generateFilter.setRequestResolver(getGenerateRequestResolver(http)); generateFilter.setRequestResolver(getGenerateRequestResolver());
http.addFilter(postProcess(generateFilter)); http.addFilter(postProcess(generateFilter));
http.addFilter(DefaultResourcesFilter.css()); http.addFilter(DefaultResourcesFilter.css());
} }
private OneTimeTokenGenerationSuccessHandler getOneTimeTokenGenerationSuccessHandler(H http) { private OneTimeTokenGenerationSuccessHandler getOneTimeTokenGenerationSuccessHandler() {
if (this.oneTimeTokenGenerationSuccessHandler == null) { if (this.oneTimeTokenGenerationSuccessHandler == null) {
this.oneTimeTokenGenerationSuccessHandler = getBeanOrNull(http, OneTimeTokenGenerationSuccessHandler.class); this.oneTimeTokenGenerationSuccessHandler = this.context
.getBeanProvider(OneTimeTokenGenerationSuccessHandler.class)
.getIfUnique();
} }
if (this.oneTimeTokenGenerationSuccessHandler == null) { if (this.oneTimeTokenGenerationSuccessHandler == null) {
throw new IllegalStateException(""" throw new IllegalStateException("""
@ -170,12 +171,12 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
http.addFilter(postProcess(submitPage)); http.addFilter(postProcess(submitPage));
} }
private AuthenticationProvider getAuthenticationProvider(H http) { private AuthenticationProvider getAuthenticationProvider() {
if (this.authenticationProvider != null) { if (this.authenticationProvider != null) {
return this.authenticationProvider; return this.authenticationProvider;
} }
UserDetailsService userDetailsService = getContext().getBean(UserDetailsService.class); UserDetailsService userDetailsService = this.context.getBean(UserDetailsService.class);
this.authenticationProvider = new OneTimeTokenAuthenticationProvider(getOneTimeTokenService(http), this.authenticationProvider = new OneTimeTokenAuthenticationProvider(getOneTimeTokenService(),
userDetailsService); userDetailsService);
return this.authenticationProvider; return this.authenticationProvider;
} }
@ -321,44 +322,34 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
return this; return this;
} }
private GenerateOneTimeTokenRequestResolver getGenerateRequestResolver(H http) { private GenerateOneTimeTokenRequestResolver getGenerateRequestResolver() {
if (this.requestResolver != null) { if (this.requestResolver != null) {
return this.requestResolver; return this.requestResolver;
} }
GenerateOneTimeTokenRequestResolver bean = getBeanOrNull(http, GenerateOneTimeTokenRequestResolver.class); this.requestResolver = this.context.getBeanProvider(GenerateOneTimeTokenRequestResolver.class)
this.requestResolver = Objects.requireNonNullElseGet(bean, DefaultGenerateOneTimeTokenRequestResolver::new); .getIfUnique(DefaultGenerateOneTimeTokenRequestResolver::new);
return this.requestResolver; return this.requestResolver;
} }
private OneTimeTokenService getOneTimeTokenService(H http) { private OneTimeTokenService getOneTimeTokenService() {
if (this.oneTimeTokenService != null) { if (this.oneTimeTokenService != null) {
return this.oneTimeTokenService; return this.oneTimeTokenService;
} }
OneTimeTokenService bean = getBeanOrNull(http, OneTimeTokenService.class); this.oneTimeTokenService = this.context.getBeanProvider(OneTimeTokenService.class)
if (bean != null) { .getIfUnique(InMemoryOneTimeTokenService::new);
this.oneTimeTokenService = bean;
}
else {
this.oneTimeTokenService = new InMemoryOneTimeTokenService();
}
return this.oneTimeTokenService; return this.oneTimeTokenService;
} }
private <C> C getBeanOrNull(H http, Class<C> clazz) {
ApplicationContext context = http.getSharedObject(ApplicationContext.class);
if (context == null) {
return null;
}
return context.getBeanProvider(clazz).getIfUnique();
}
private Map<String, String> hiddenInputs(HttpServletRequest request) { private Map<String, String> hiddenInputs(HttpServletRequest request) {
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
return (token != null) ? Collections.singletonMap(token.getParameterName(), token.getToken()) return (token != null) ? Collections.singletonMap(token.getParameterName(), token.getToken())
: Collections.emptyMap(); : Collections.emptyMap();
} }
/**
* @deprecated Use this.context instead
*/
@Deprecated
public ApplicationContext getContext() { public ApplicationContext getContext() {
return this.context; return this.context;
} }