Polish gh-7840

This commit is contained in:
Joe Grandja 2020-02-24 09:28:00 -05:00
parent 65b5d468fb
commit c6da7b2dd6

View File

@ -17,9 +17,9 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.ClientAuthorizationException;
@ -61,7 +61,6 @@ import reactor.util.context.Context;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.time.Duration; import java.time.Duration;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -559,7 +558,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
Map<String, Object> attrs = request.attributes(); Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs); Authentication authentication = getAuthentication(attrs);
if (authentication == null) { if (authentication == null) {
authentication = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); authentication = createAuthentication(authorizedClient.getPrincipalName());
} }
HttpServletRequest servletRequest = getRequest(attrs); HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs); HttpServletResponse servletResponse = getResponse(attrs);
@ -609,52 +608,20 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME); return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
} }
private static class PrincipalNameAuthentication implements Authentication { private static Authentication createAuthentication(final String principalName) {
private final String principalName; Assert.hasText(principalName, "principalName cannot be empty");
private PrincipalNameAuthentication(String principalName) { return new AbstractAuthenticationToken(null) {
Assert.hasText(principalName, "principalName cannot be empty"); @Override
this.principalName = principalName; public Object getCredentials() {
} return "";
}
@Override @Override
public Collection<? extends GrantedAuthority> getAuthorities() { public Object getPrincipal() {
throw unsupported(); return principalName;
} }
};
@Override
public Object getCredentials() {
throw unsupported();
}
@Override
public Object getDetails() {
throw unsupported();
}
@Override
public Object getPrincipal() {
return getName();
}
@Override
public boolean isAuthenticated() {
throw unsupported();
}
@Override
public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
throw unsupported();
}
@Override
public String getName() {
return this.principalName;
}
private UnsupportedOperationException unsupported() {
return new UnsupportedOperationException("Not Supported");
}
} }
/** /**
@ -711,7 +678,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
ClientAuthorizationException authorizationException = new ClientAuthorizationException( ClientAuthorizationException authorizationException = new ClientAuthorizationException(
oauth2Error, authorizedClient.getClientRegistration().getRegistrationId()); oauth2Error, authorizedClient.getClientRegistration().getRegistrationId());
Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); Authentication principal = createAuthentication(authorizedClient.getPrincipalName());
HttpServletRequest servletRequest = getRequest(attrs); HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs); HttpServletResponse servletResponse = getResponse(attrs);
@ -779,7 +746,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
ClientAuthorizationException authorizationException = new ClientAuthorizationException( ClientAuthorizationException authorizationException = new ClientAuthorizationException(
oauth2Error, authorizedClient.getClientRegistration().getRegistrationId(), exception); oauth2Error, authorizedClient.getClientRegistration().getRegistrationId(), exception);
Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); Authentication principal = createAuthentication(authorizedClient.getPrincipalName());
HttpServletRequest servletRequest = getRequest(attrs); HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs); HttpServletResponse servletResponse = getResponse(attrs);
@ -804,7 +771,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return Mono.empty(); return Mono.empty();
} }
Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); Authentication principal = createAuthentication(authorizedClient.getPrincipalName());
HttpServletRequest servletRequest = getRequest(attrs); HttpServletRequest servletRequest = getRequest(attrs);
HttpServletResponse servletResponse = getResponse(attrs); HttpServletResponse servletResponse = getResponse(attrs);