Enable null-safety in spring-security-oauth2-authorization-server

Closes gh-18937
This commit is contained in:
Joe Grandja 2026-03-19 13:56:17 -04:00
parent fe24bd3d0c
commit 1db0d4f83d
166 changed files with 1861 additions and 858 deletions

View File

@ -1,5 +1,6 @@
plugins {
id 'compile-warnings-error'
id 'security-nullability'
}
apply plugin: 'io.spring.convention.spring-module'

View File

@ -23,7 +23,8 @@ import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert;
/**
@ -90,8 +91,7 @@ public final class InMemoryOAuth2AuthorizationConsentService implements OAuth2Au
}
@Override
@Nullable
public OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
public @Nullable OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
Assert.hasText(registeredClientId, "registeredClientId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
int id = getId(registeredClientId, principalName);

View File

@ -23,7 +23,8 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2DeviceCode;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@ -125,17 +126,15 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
}
}
@Nullable
@Override
public OAuth2Authorization findById(String id) {
public @Nullable OAuth2Authorization findById(String id) {
Assert.hasText(id, "id cannot be empty");
OAuth2Authorization authorization = this.authorizations.get(id);
return (authorization != null) ? authorization : this.initializedAuthorizations.get(id);
}
@Nullable
@Override
public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
public @Nullable OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
Assert.hasText(token, "token cannot be empty");
for (OAuth2Authorization authorization : this.authorizations.values()) {
if (hasToken(authorization, token, tokenType)) {

View File

@ -25,6 +25,8 @@ import java.util.List;
import java.util.Set;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.context.annotation.ImportRuntimeHints;
@ -35,7 +37,6 @@ import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.lang.Nullable;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@ -162,8 +163,7 @@ public class JdbcOAuth2AuthorizationConsentService implements OAuth2Authorizatio
}
@Override
@Nullable
public OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
public @Nullable OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
Assert.hasText(registeredClientId, "registeredClientId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
SqlParameterValue[] parameters = new SqlParameterValue[] {
@ -281,7 +281,7 @@ public class JdbcOAuth2AuthorizationConsentService implements OAuth2Authorizatio
static class JdbcOAuth2AuthorizationConsentServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.resources()
.registerResource(new ClassPathResource(
"org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql"));

View File

@ -17,6 +17,7 @@
package org.springframework.security.oauth2.server.authorization;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
@ -36,6 +37,7 @@ import java.util.function.Function;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
import tools.jackson.databind.JacksonModule;
import tools.jackson.databind.json.JsonMapper;
@ -54,7 +56,6 @@ import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.lob.DefaultLobHandler;
import org.springframework.jdbc.support.lob.LobCreator;
import org.springframework.jdbc.support.lob.LobHandler;
import org.springframework.lang.Nullable;
import org.springframework.security.jackson.SecurityJacksonModules;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -210,7 +211,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
private static Map<String, ColumnMetadata> columnMetadataMap;
private static final Map<String, ColumnMetadata> columnMetadataMap = new HashMap<>();
private final JdbcOperations jdbcOperations;
@ -292,18 +293,16 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
this.jdbcOperations.update(REMOVE_AUTHORIZATION_SQL, pss);
}
@Nullable
@Override
public OAuth2Authorization findById(String id) {
public @Nullable OAuth2Authorization findById(String id) {
Assert.hasText(id, "id cannot be empty");
List<SqlParameterValue> parameters = new ArrayList<>();
parameters.add(new SqlParameterValue(Types.VARCHAR, id));
return findBy(PK_FILTER, parameters);
}
@Nullable
@Override
public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
public @Nullable OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
Assert.hasText(token, "token cannot be empty");
List<SqlParameterValue> parameters = new ArrayList<>();
if (tokenType == null) {
@ -347,7 +346,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
return null;
}
private OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
private @Nullable OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
try (LobCreator lobCreator = getLobHandler().getLobCreator()) {
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
parameters.toArray());
@ -399,7 +398,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
}
private static void initColumnMetadata(JdbcOperations jdbcOperations) {
columnMetadataMap = new HashMap<>();
columnMetadataMap.clear();
ColumnMetadata columnMetadata;
columnMetadata = getColumnMetadata(jdbcOperations, "attributes", Types.BLOB);
@ -432,32 +431,37 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, String columnName,
int defaultDataType) {
Integer dataType = jdbcOperations.execute((ConnectionCallback<Integer>) (conn) -> {
DatabaseMetaData databaseMetaData = conn.getMetaData();
ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
if (rs.next()) {
return rs.getInt("DATA_TYPE");
@Nullable Integer dataType = jdbcOperations.execute(new ConnectionCallback<@Nullable Integer>() {
@Override
public @Nullable Integer doInConnection(Connection conn) throws SQLException {
DatabaseMetaData databaseMetaData = conn.getMetaData();
ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
// NOTE: (Applies to HSQL)
// When a database object is created with one of the CREATE statements or
// renamed with the ALTER statement,
// if the name is enclosed in double quotes, the exact name is used as the
// case-normal form.
// But if it is not enclosed in double quotes,
// the name is converted to uppercase and this uppercase version is stored
// in
// the database as the case-normal form.
rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(Locale.ENGLISH),
columnName.toUpperCase(Locale.ENGLISH));
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
return null;
}
// NOTE: (Applies to HSQL)
// When a database object is created with one of the CREATE statements or
// renamed with the ALTER statement,
// if the name is enclosed in double quotes, the exact name is used as the
// case-normal form.
// But if it is not enclosed in double quotes,
// the name is converted to uppercase and this uppercase version is stored in
// the database as the case-normal form.
rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(Locale.ENGLISH),
columnName.toUpperCase(Locale.ENGLISH));
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
return null;
});
return new ColumnMetadata(columnName, (dataType != null) ? dataType : defaultDataType);
}
private static SqlParameterValue mapToSqlParameter(String columnName, String value) {
private static SqlParameterValue mapToSqlParameter(String columnName, @Nullable String value) {
ColumnMetadata columnMetadata = columnMetadataMap.get(columnName);
Assert.notNull(columnMetadata, "Column metadata not found for column '" + columnName + "'");
return (Types.BLOB == columnMetadata.getDataType() && StringUtils.hasText(value))
? new SqlParameterValue(Types.BLOB, value.getBytes(StandardCharsets.UTF_8))
: new SqlParameterValue(columnMetadata.getDataType(), value);
@ -610,6 +614,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
.equalsIgnoreCase(rs.getString("access_token_type"))) {
tokenType = OAuth2AccessToken.TokenType.DPOP;
}
Assert.notNull(tokenType, "access_token_type must be BEARER or DPOP");
Set<String> scopes = Collections.emptySet();
String accessTokenScopes = rs.getString("access_token_scopes");
@ -627,8 +632,13 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
Map<String, Object> oidcTokenMetadata = parseMap(getLobValue(rs, OIDC_ID_TOKEN_METADATA));
OidcIdToken oidcToken = new OidcIdToken(oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt,
(Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
@SuppressWarnings("unchecked")
Map<String, Object> idTokenClaims = (Map<String, Object>) oidcTokenMetadata
.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME);
if (idTokenClaims == null) {
idTokenClaims = Collections.emptyMap();
}
OidcIdToken oidcToken = new OidcIdToken(oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, idTokenClaims);
builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
}
@ -670,9 +680,10 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
return builder.build();
}
private String getLobValue(ResultSet rs, String columnName) throws SQLException {
private @Nullable String getLobValue(ResultSet rs, String columnName) throws SQLException {
String columnValue = null;
ColumnMetadata columnMetadata = columnMetadataMap.get(columnName);
Assert.notNull(columnMetadata, "Column metadata not found for column '" + columnName + "'");
if (Types.BLOB == columnMetadata.getDataType()) {
byte[] columnValueBytes = this.lobHandler.getBlobAsBytes(rs, columnName);
if (columnValueBytes != null) {
@ -701,7 +712,10 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
return this.lobHandler;
}
private Map<String, Object> parseMap(String data) {
private Map<String, Object> parseMap(@Nullable String data) {
if (!StringUtils.hasText(data)) {
return Collections.emptyMap();
}
try {
return readValue(data);
}
@ -849,7 +863,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
}
private <T extends OAuth2Token> List<SqlParameterValue> toSqlParameterList(String tokenColumnName,
String tokenMetadataColumnName, OAuth2Authorization.Token<T> token) {
String tokenMetadataColumnName, OAuth2Authorization.@Nullable Token<T> token) {
List<SqlParameterValue> parameters = new ArrayList<>();
String tokenValue = null;
@ -933,7 +947,8 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
}
@Override
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
protected void doSetValue(PreparedStatement ps, int parameterPosition, @Nullable Object argValue)
throws SQLException {
if (argValue instanceof SqlParameterValue paramValue) {
if (paramValue.getSqlType() == Types.BLOB) {
if (paramValue.getValue() != null) {
@ -983,7 +998,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
static class JdbcOAuth2AuthorizationServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.resources()
.registerResource(new ClassPathResource(
"org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql"));

View File

@ -28,7 +28,8 @@ import java.util.Set;
import java.util.UUID;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@ -58,19 +59,19 @@ public class OAuth2Authorization implements Serializable {
@Serial
private static final long serialVersionUID = 880363144799377926L;
private String id;
private @Nullable String id;
private String registeredClientId;
private @Nullable String registeredClientId;
private String principalName;
private @Nullable String principalName;
private AuthorizationGrantType authorizationGrantType;
private @Nullable AuthorizationGrantType authorizationGrantType;
private Set<String> authorizedScopes;
private @Nullable Set<String> authorizedScopes;
private Map<Class<? extends OAuth2Token>, Token<?>> tokens;
private @Nullable Map<Class<? extends OAuth2Token>, Token<?>> tokens;
private Map<String, Object> attributes;
private @Nullable Map<String, Object> attributes;
protected OAuth2Authorization() {
}
@ -80,6 +81,7 @@ public class OAuth2Authorization implements Serializable {
* @return the identifier for the authorization
*/
public String getId() {
Assert.notNull(this.id, "id cannot be null");
return this.id;
}
@ -88,6 +90,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@link RegisteredClient#getId()}
*/
public String getRegisteredClientId() {
Assert.notNull(this.registeredClientId, "registeredClientId cannot be null");
return this.registeredClientId;
}
@ -96,6 +99,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@code Principal} name of the resource owner (or client)
*/
public String getPrincipalName() {
Assert.notNull(this.principalName, "principalName cannot be null");
return this.principalName;
}
@ -105,6 +109,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@link AuthorizationGrantType} used for the authorization
*/
public AuthorizationGrantType getAuthorizationGrantType() {
Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null");
return this.authorizationGrantType;
}
@ -113,14 +118,16 @@ public class OAuth2Authorization implements Serializable {
* @return the {@code Set} of authorized scope(s)
*/
public Set<String> getAuthorizedScopes() {
Assert.notNull(this.authorizedScopes, "authorizedScopes cannot be null");
return this.authorizedScopes;
}
/**
* Returns the {@link Token} of type {@link OAuth2AccessToken}.
* @return the {@link Token} of type {@link OAuth2AccessToken}
* @return the {@link Token} of type {@link OAuth2AccessToken}, or {@code null} if not
* available
*/
public Token<OAuth2AccessToken> getAccessToken() {
public @Nullable Token<OAuth2AccessToken> getAccessToken() {
return getToken(OAuth2AccessToken.class);
}
@ -129,8 +136,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@link Token} of type {@link OAuth2RefreshToken}, or {@code null} if
* not available
*/
@Nullable
public Token<OAuth2RefreshToken> getRefreshToken() {
public @Nullable Token<OAuth2RefreshToken> getRefreshToken() {
return getToken(OAuth2RefreshToken.class);
}
@ -140,10 +146,10 @@ public class OAuth2Authorization implements Serializable {
* @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends OAuth2Token> Token<T> getToken(Class<T> tokenType) {
public <T extends OAuth2Token> @Nullable Token<T> getToken(Class<T> tokenType) {
Assert.notNull(tokenType, "tokenType cannot be null");
Assert.notNull(this.tokens, "tokens cannot be null");
Token<?> token = this.tokens.get(tokenType);
return (token != null) ? (Token<T>) token : null;
}
@ -154,10 +160,10 @@ public class OAuth2Authorization implements Serializable {
* @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends OAuth2Token> Token<T> getToken(String tokenValue) {
public <T extends OAuth2Token> @Nullable Token<T> getToken(String tokenValue) {
Assert.hasText(tokenValue, "tokenValue cannot be empty");
Assert.notNull(this.tokens, "tokens cannot be null");
for (Token<?> token : this.tokens.values()) {
if (token.getToken().getTokenValue().equals(tokenValue)) {
return (Token<T>) token;
@ -171,6 +177,7 @@ public class OAuth2Authorization implements Serializable {
* @return a {@code Map} of the attribute(s)
*/
public Map<String, Object> getAttributes() {
Assert.notNull(this.attributes, "attributes cannot be null");
return this.attributes;
}
@ -181,10 +188,10 @@ public class OAuth2Authorization implements Serializable {
* @return the value of an attribute associated to the authorization, or {@code null}
* if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T> T getAttribute(String name) {
public <T> @Nullable T getAttribute(String name) {
Assert.hasText(name, "name cannot be empty");
Assert.notNull(this.attributes, "attributes cannot be null");
return (T) this.attributes.get(name);
}
@ -230,6 +237,7 @@ public class OAuth2Authorization implements Serializable {
*/
public static Builder from(OAuth2Authorization authorization) {
Assert.notNull(authorization, "authorization cannot be null");
Assert.notNull(authorization.tokens, "tokens cannot be null");
return new Builder(authorization.getRegisteredClientId()).id(authorization.getId())
.principalName(authorization.getPrincipalName())
.authorizationGrantType(authorization.getAuthorizationGrantType())
@ -324,8 +332,7 @@ public class OAuth2Authorization implements Serializable {
* Returns the claims associated to the token.
* @return a {@code Map} of the claims, or {@code null} if not available
*/
@Nullable
public Map<String, Object> getClaims() {
public @Nullable Map<String, Object> getClaims() {
return getMetadata(CLAIMS_METADATA_NAME);
}
@ -335,9 +342,8 @@ public class OAuth2Authorization implements Serializable {
* @param <V> the value type of the metadata
* @return the value of the metadata, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <V> V getMetadata(String name) {
public <V> @Nullable V getMetadata(String name) {
Assert.hasText(name, "name cannot be empty");
return (V) this.metadata.get(name);
}
@ -380,15 +386,15 @@ public class OAuth2Authorization implements Serializable {
*/
public static class Builder {
private String id;
private @Nullable String id;
private final String registeredClientId;
private String principalName;
private @Nullable String principalName;
private AuthorizationGrantType authorizationGrantType;
private @Nullable AuthorizationGrantType authorizationGrantType;
private Set<String> authorizedScopes;
private @Nullable Set<String> authorizedScopes;
private Map<Class<? extends OAuth2Token>, Token<?>> tokens = new HashMap<>();
@ -503,8 +509,10 @@ public class OAuth2Authorization implements Serializable {
token(token, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
Token<?> accessToken = this.tokens.get(OAuth2AccessToken.class);
token(accessToken.getToken(),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
if (accessToken != null) {
token(accessToken.getToken(),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
}
Token<?> authorizationCode = this.tokens.get(OAuth2AuthorizationCode.class);
if (authorizationCode != null && !authorizationCode.isInvalidated()) {

View File

@ -24,7 +24,6 @@ import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import org.springframework.lang.NonNull;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@ -99,8 +98,9 @@ public final class OAuth2AuthorizationConsent implements Serializable {
public Set<String> getScopes() {
Set<String> authorities = new HashSet<>();
for (GrantedAuthority authority : getAuthorities()) {
if (authority.getAuthority().startsWith(AUTHORITIES_SCOPE_PREFIX)) {
authorities.add(authority.getAuthority().substring(AUTHORITIES_SCOPE_PREFIX.length()));
String authorityValue = authority.getAuthority();
if (authorityValue != null && authorityValue.startsWith(AUTHORITIES_SCOPE_PREFIX)) {
authorities.add(authorityValue.substring(AUTHORITIES_SCOPE_PREFIX.length()));
}
}
return authorities;
@ -146,7 +146,7 @@ public final class OAuth2AuthorizationConsent implements Serializable {
* @param principalName the {@code Principal} name
* @return the {@link Builder}
*/
public static Builder withId(@NonNull String registeredClientId, @NonNull String principalName) {
public static Builder withId(String registeredClientId, String principalName) {
Assert.hasText(registeredClientId, "registeredClientId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
return new Builder(registeredClientId, principalName);

View File

@ -18,7 +18,8 @@ package org.springframework.security.oauth2.server.authorization;
import java.security.Principal;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
/**
@ -50,7 +51,6 @@ public interface OAuth2AuthorizationConsentService {
* @param principalName the name of the {@link Principal}
* @return the {@link OAuth2AuthorizationConsent} if found, otherwise {@code null}
*/
@Nullable
OAuth2AuthorizationConsent findById(String registeredClientId, String principalName);
@Nullable OAuth2AuthorizationConsent findById(String registeredClientId, String principalName);
}

View File

@ -19,8 +19,11 @@ package org.springframework.security.oauth2.server.authorization;
import java.net.URL;
import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
import org.springframework.util.Assert;
/**
* A {@link ClaimAccessor} for the "claims" an Authorization Server describes about its
@ -57,7 +60,9 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@code URL} the Authorization Server asserts as its Issuer Identifier
*/
default URL getIssuer() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.ISSUER);
URL issuer = getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.ISSUER);
Assert.notNull(issuer, "issuer cannot be null");
return issuer;
}
/**
@ -66,7 +71,9 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@code URL} of the OAuth 2.0 Authorization Endpoint
*/
default URL getAuthorizationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.AUTHORIZATION_ENDPOINT);
URL authorizationEndpoint = getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.AUTHORIZATION_ENDPOINT);
Assert.notNull(authorizationEndpoint, "authorizationEndpoint cannot be null");
return authorizationEndpoint;
}
/**
@ -74,7 +81,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (pushed_authorization_request_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Pushed Authorization Request Endpoint
*/
default URL getPushedAuthorizationRequestEndpoint() {
default @Nullable URL getPushedAuthorizationRequestEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.PUSHED_AUTHORIZATION_REQUEST_ENDPOINT);
}
@ -83,7 +90,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (device_authorization_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Device Authorization Endpoint
*/
default URL getDeviceAuthorizationEndpoint() {
default @Nullable URL getDeviceAuthorizationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.DEVICE_AUTHORIZATION_ENDPOINT);
}
@ -92,7 +99,9 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@code URL} of the OAuth 2.0 Token Endpoint
*/
default URL getTokenEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.TOKEN_ENDPOINT);
URL tokenEndpoint = getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.TOKEN_ENDPOINT);
Assert.notNull(tokenEndpoint, "tokenEndpoint cannot be null");
return tokenEndpoint;
}
/**
@ -100,7 +109,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (token_endpoint_auth_methods_supported)}.
* @return the client authentication methods supported by the OAuth 2.0 Token Endpoint
*/
default List<String> getTokenEndpointAuthenticationMethods() {
default @Nullable List<String> getTokenEndpointAuthenticationMethods() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED);
}
@ -108,7 +117,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* Returns the {@code URL} of the JSON Web Key Set {@code (jwks_uri)}.
* @return the {@code URL} of the JSON Web Key Set
*/
default URL getJwkSetUrl() {
default @Nullable URL getJwkSetUrl() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.JWKS_URI);
}
@ -116,7 +125,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* Returns the OAuth 2.0 {@code scope} values supported {@code (scopes_supported)}.
* @return the OAuth 2.0 {@code scope} values supported
*/
default List<String> getScopes() {
default @Nullable List<String> getScopes() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.SCOPES_SUPPORTED);
}
@ -126,7 +135,10 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the OAuth 2.0 {@code response_type} values supported
*/
default List<String> getResponseTypes() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.RESPONSE_TYPES_SUPPORTED);
List<String> responseTypes = getClaimAsStringList(
OAuth2AuthorizationServerMetadataClaimNames.RESPONSE_TYPES_SUPPORTED);
Assert.notNull(responseTypes, "responseTypes cannot be null");
return responseTypes;
}
/**
@ -134,7 +146,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (grant_types_supported)}.
* @return the OAuth 2.0 {@code grant_type} values supported
*/
default List<String> getGrantTypes() {
default @Nullable List<String> getGrantTypes() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.GRANT_TYPES_SUPPORTED);
}
@ -143,7 +155,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (revocation_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Token Revocation Endpoint
*/
default URL getTokenRevocationEndpoint() {
default @Nullable URL getTokenRevocationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.REVOCATION_ENDPOINT);
}
@ -153,7 +165,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the client authentication methods supported by the OAuth 2.0 Token
* Revocation Endpoint
*/
default List<String> getTokenRevocationEndpointAuthenticationMethods() {
default @Nullable List<String> getTokenRevocationEndpointAuthenticationMethods() {
return getClaimAsStringList(
OAuth2AuthorizationServerMetadataClaimNames.REVOCATION_ENDPOINT_AUTH_METHODS_SUPPORTED);
}
@ -163,7 +175,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (introspection_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Token Introspection Endpoint
*/
default URL getTokenIntrospectionEndpoint() {
default @Nullable URL getTokenIntrospectionEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.INTROSPECTION_ENDPOINT);
}
@ -173,7 +185,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the client authentication methods supported by the OAuth 2.0 Token
* Introspection Endpoint
*/
default List<String> getTokenIntrospectionEndpointAuthenticationMethods() {
default @Nullable List<String> getTokenIntrospectionEndpointAuthenticationMethods() {
return getClaimAsStringList(
OAuth2AuthorizationServerMetadataClaimNames.INTROSPECTION_ENDPOINT_AUTH_METHODS_SUPPORTED);
}
@ -183,7 +195,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (registration_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Dynamic Client Registration Endpoint
*/
default URL getClientRegistrationEndpoint() {
default @Nullable URL getClientRegistrationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.REGISTRATION_ENDPOINT);
}
@ -192,7 +204,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* supported {@code (code_challenge_methods_supported)}.
* @return the {@code code_challenge_method} values supported
*/
default List<String> getCodeChallengeMethods() {
default @Nullable List<String> getCodeChallengeMethods() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.CODE_CHALLENGE_METHODS_SUPPORTED);
}
@ -213,7 +225,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@link JwsAlgorithms JSON Web Signature (JWS) algorithms} supported for
* DPoP Proof JWTs
*/
default List<String> getDPoPSigningAlgorithms() {
default @Nullable List<String> getDPoPSigningAlgorithms() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.DPOP_SIGNING_ALG_VALUES_SUPPORTED);
}

View File

@ -16,7 +16,7 @@
package org.springframework.security.oauth2.server.authorization;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
/**
* Implementations of this interface are responsible for the management of
@ -47,8 +47,7 @@ public interface OAuth2AuthorizationService {
* @param id the authorization identifier
* @return the {@link OAuth2Authorization} if found, otherwise {@code null}
*/
@Nullable
OAuth2Authorization findById(String id);
@Nullable OAuth2Authorization findById(String id);
/**
* Returns the {@link OAuth2Authorization} containing the provided {@code token}, or
@ -57,7 +56,6 @@ public interface OAuth2AuthorizationService {
* @param tokenType the {@link OAuth2TokenType token type}
* @return the {@link OAuth2Authorization} if found, otherwise {@code null}
*/
@Nullable
OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType);
@Nullable OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType);
}

View File

@ -20,7 +20,10 @@ import java.net.URL;
import java.time.Instant;
import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.util.Assert;
/**
* A {@link ClaimAccessor} for the claims that are contained in the OAuth 2.0 Client
@ -41,7 +44,9 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the Client Identifier
*/
default String getClientId() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_ID);
String clientId = getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_ID);
Assert.notNull(clientId, "clientId cannot be null");
return clientId;
}
/**
@ -49,7 +54,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (client_id_issued_at)}.
* @return the time at which the Client Identifier was issued
*/
default Instant getClientIdIssuedAt() {
default @Nullable Instant getClientIdIssuedAt() {
return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT);
}
@ -57,7 +62,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* Returns the Client Secret {@code (client_secret)}.
* @return the Client Secret
*/
default String getClientSecret() {
default @Nullable String getClientSecret() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_SECRET);
}
@ -66,7 +71,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (client_secret_expires_at)}.
* @return the time at which the {@code client_secret} will expire
*/
default Instant getClientSecretExpiresAt() {
default @Nullable Instant getClientSecretExpiresAt() {
return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT);
}
@ -75,7 +80,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (client_name)}.
* @return the name of the Client to be presented to the End-User
*/
default String getClientName() {
default @Nullable String getClientName() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_NAME);
}
@ -84,7 +89,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (redirect_uris)}.
* @return the redirection {@code URI} values used by the Client
*/
default List<String> getRedirectUris() {
default @Nullable List<String> getRedirectUris() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.REDIRECT_URIS);
}
@ -93,7 +98,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (token_endpoint_auth_method)}.
* @return the authentication method used by the Client for the Token Endpoint
*/
default String getTokenEndpointAuthenticationMethod() {
default @Nullable String getTokenEndpointAuthenticationMethod() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD);
}
@ -103,7 +108,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the OAuth 2.0 {@code grant_type} values that the Client will restrict
* itself to using
*/
default List<String> getGrantTypes() {
default @Nullable List<String> getGrantTypes() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.GRANT_TYPES);
}
@ -113,7 +118,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the OAuth 2.0 {@code response_type} values that the Client will restrict
* itself to using
*/
default List<String> getResponseTypes() {
default @Nullable List<String> getResponseTypes() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES);
}
@ -123,7 +128,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the OAuth 2.0 {@code scope} values that the Client will restrict itself to
* using
*/
default List<String> getScopes() {
default @Nullable List<String> getScopes() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.SCOPE);
}
@ -131,7 +136,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* Returns the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}.
* @return the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
*/
default URL getJwkSetUrl() {
default @Nullable URL getJwkSetUrl() {
return getClaimAsURL(OAuth2ClientMetadataClaimNames.JWKS_URI);
}

View File

@ -20,6 +20,8 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import org.jspecify.annotations.Nullable;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.BindingReflectionHintsRegistrar;
import org.springframework.aot.hint.MemberCategory;
@ -84,7 +86,7 @@ class OAuth2AuthorizationServerBeanRegistrationAotProcessor implements BeanRegis
private boolean jacksonContributed;
@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
public @Nullable BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
boolean isJdbcBasedOAuth2AuthorizationService = JdbcOAuth2AuthorizationService.class
.isAssignableFrom(registeredBean.getBeanClass());

View File

@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.aot.hint;
import org.jspecify.annotations.Nullable;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
@ -35,7 +37,7 @@ import org.springframework.security.oauth2.server.authorization.web.OAuth2Author
class OAuth2AuthorizationServerRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.reflection()
.registerType(OAuth2AuthorizationCodeRequestAuthenticationProvider.class,
MemberCategory.INVOKE_DECLARED_METHODS);

View File

@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Spring Framework AOT {@link org.springframework.aot.hint.RuntimeHints} for GraalVM
* native images for the authorization server module.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.aot.hint;
import org.jspecify.annotations.NullMarked;

View File

@ -23,7 +23,8 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.util.Assert;
@ -48,9 +49,9 @@ abstract class AbstractOAuth2AuthorizationCodeRequestAuthenticationToken extends
private final Authentication principal;
private final String redirectUri;
private final @Nullable String redirectUri;
private final String state;
private final @Nullable String state;
private final Set<String> scopes;
@ -103,8 +104,7 @@ abstract class AbstractOAuth2AuthorizationCodeRequestAuthenticationToken extends
* Returns the redirect uri.
* @return the redirect uri
*/
@Nullable
public String getRedirectUri() {
public @Nullable String getRedirectUri() {
return this.redirectUri;
}
@ -112,8 +112,7 @@ abstract class AbstractOAuth2AuthorizationCodeRequestAuthenticationToken extends
* Returns the state.
* @return the state
*/
@Nullable
public String getState() {
public @Nullable String getState() {
return this.state;
}

View File

@ -20,6 +20,7 @@ import java.time.Instant;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationProvider;
@ -92,7 +93,7 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
// @formatter:off
@ -105,7 +106,7 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
throw invalidClientException(OAuth2ParameterNames.CLIENT_ID);
}
if (this.logger.isTraceEnabled()) {
@ -114,26 +115,27 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
if (!registeredClient.getClientAuthenticationMethods()
.contains(clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method");
throw invalidClientException("authentication_method");
}
if (clientAuthentication.getCredentials() == null) {
throwInvalidClient("credentials");
Object credentials = clientAuthentication.getCredentials();
if (credentials == null) {
throw invalidClientException("credentials");
}
String clientSecret = clientAuthentication.getCredentials().toString();
String clientSecret = credentials.toString();
if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format(
"Invalid request: client_secret does not match" + " for registered client '%s'",
registeredClient.getId()));
}
throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET);
throw invalidClientException(OAuth2ParameterNames.CLIENT_SECRET);
}
if (registeredClient.getClientSecretExpiresAt() != null
&& Instant.now().isAfter(registeredClient.getClientSecretExpiresAt())) {
throwInvalidClient("client_secret_expires_at");
throw invalidClientException("client_secret_expires_at");
}
if (this.passwordEncoder.upgradeEncoding(registeredClient.getClientSecret())) {
@ -164,10 +166,10 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
}
private static void throwInvalidClient(String parameterName) {
private static OAuth2AuthenticationException invalidClientException(String parameterName) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error);
return new OAuth2AuthenticationException(error);
}
}

View File

@ -24,6 +24,7 @@ import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -65,7 +66,7 @@ final class CodeVerifierAuthenticator {
void authenticateRequired(OAuth2ClientAuthenticationToken clientAuthentication, RegisteredClient registeredClient) {
if (!authenticate(clientAuthentication, registeredClient)) {
throwInvalidGrant(PkceParameterNames.CODE_VERIFIER);
throw invalidGrantException(PkceParameterNames.CODE_VERIFIER);
}
}
@ -82,10 +83,11 @@ final class CodeVerifierAuthenticator {
return false;
}
OAuth2Authorization authorization = this.authorizationService
.findByToken((String) parameters.get(OAuth2ParameterNames.CODE), AUTHORIZATION_CODE_TOKEN_TYPE);
String code = (String) parameters.get(OAuth2ParameterNames.CODE);
Assert.hasText(code, "code cannot be empty");
OAuth2Authorization authorization = this.authorizationService.findByToken(code, AUTHORIZATION_CODE_TOKEN_TYPE);
if (authorization == null) {
throwInvalidGrant(OAuth2ParameterNames.CODE);
throw invalidGrantException(OAuth2ParameterNames.CODE);
}
if (this.logger.isTraceEnabled()) {
@ -94,6 +96,7 @@ final class CodeVerifierAuthenticator {
OAuth2AuthorizationRequest authorizationRequest = authorization
.getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
String codeChallenge = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE);
@ -105,7 +108,7 @@ final class CodeVerifierAuthenticator {
"Invalid request: code_challenge is required" + " for registered client '%s'",
registeredClient.getId()));
}
throwInvalidGrant(PkceParameterNames.CODE_CHALLENGE);
throw invalidGrantException(PkceParameterNames.CODE_CHALLENGE);
}
else {
if (this.logger.isTraceEnabled()) {
@ -119,6 +122,7 @@ final class CodeVerifierAuthenticator {
this.logger.trace("Validated code verifier parameters");
}
Assert.hasText(codeChallenge, "codeChallenge cannot be empty");
String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
@ -127,7 +131,7 @@ final class CodeVerifierAuthenticator {
"Invalid request: code_verifier is missing or invalid" + " for registered client '%s'",
registeredClient.getId()));
}
throwInvalidGrant(PkceParameterNames.CODE_VERIFIER);
throw invalidGrantException(PkceParameterNames.CODE_VERIFIER);
}
if (this.logger.isTraceEnabled()) {
@ -143,12 +147,13 @@ final class CodeVerifierAuthenticator {
return false;
}
if (!StringUtils.hasText((String) parameters.get(OAuth2ParameterNames.CODE))) {
throwInvalidGrant(OAuth2ParameterNames.CODE);
throw invalidGrantException(OAuth2ParameterNames.CODE);
}
return true;
}
private boolean codeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
private boolean codeVerifierValid(@Nullable String codeVerifier, String codeChallenge,
@Nullable String codeChallengeMethod) {
if (!StringUtils.hasText(codeVerifier)) {
return false;
}
@ -169,10 +174,10 @@ final class CodeVerifierAuthenticator {
return false;
}
private static void throwInvalidGrant(String parameterName) {
private static OAuth2AuthenticationException invalidGrantException(String parameterName) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT,
"Client authentication failed: " + parameterName, null);
throw new OAuth2AuthenticationException(error);
return new OAuth2AuthenticationException(error);
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.authentication;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
@ -24,6 +26,7 @@ import org.springframework.security.oauth2.jwt.DPoPProofJwtDecoderFactory;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
@ -42,14 +45,17 @@ final class DPoPProofVerifier {
private DPoPProofVerifier() {
}
static Jwt verifyIfAvailable(OAuth2AuthorizationGrantAuthenticationToken authorizationGrantAuthentication) {
static @Nullable Jwt verifyIfAvailable(
OAuth2AuthorizationGrantAuthenticationToken authorizationGrantAuthentication) {
String dPoPProof = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_proof");
if (!StringUtils.hasText(dPoPProof)) {
return null;
}
String method = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_method");
Assert.hasText(method, "dpop_method cannot be empty");
String targetUri = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_target_uri");
Assert.hasText(targetUri, "dpop_target_uri cannot be empty");
Jwt dPoPProofJwt;
try {

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
@ -82,7 +83,7 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
if (!JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) {
@ -92,7 +93,7 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
throw invalidClientException(OAuth2ParameterNames.CLIENT_ID);
}
if (this.logger.isTraceEnabled()) {
@ -102,21 +103,22 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
// @formatter:off
if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT) &&
!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)) {
throwInvalidClient("authentication_method");
throw invalidClientException("authentication_method");
}
// @formatter:on
if (clientAuthentication.getCredentials() == null) {
throwInvalidClient("credentials");
Object credentials = clientAuthentication.getCredentials();
if (credentials == null) {
throw invalidClientException("credentials");
}
Jwt jwtAssertion = null;
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(registeredClient);
try {
jwtAssertion = jwtDecoder.decode(clientAuthentication.getCredentials().toString());
jwtAssertion = jwtDecoder.decode(credentials.toString());
}
catch (JwtException ex) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION, ex);
throw invalidClientException(OAuth2ParameterNames.CLIENT_ASSERTION, ex);
}
if (this.logger.isTraceEnabled()) {
@ -159,14 +161,15 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
this.jwtDecoderFactory = jwtDecoderFactory;
}
private static void throwInvalidClient(String parameterName) {
throwInvalidClient(parameterName, null);
private static OAuth2AuthenticationException invalidClientException(String parameterName) {
return invalidClientException(parameterName, null);
}
private static void throwInvalidClient(String parameterName, Throwable cause) {
private static OAuth2AuthenticationException invalidClientException(String parameterName,
@Nullable Throwable cause) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error, error.toString(), cause);
return new OAuth2AuthenticationException(error, error.toString(), cause);
}
}

View File

@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler;
import org.springframework.util.Assert;
@ -47,9 +48,8 @@ public final class OAuth2AccessTokenAuthenticationContext implements OAuth2Authe
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -65,7 +65,9 @@ public final class OAuth2AccessTokenAuthenticationContext implements OAuth2Authe
* @return the {@link OAuth2AccessTokenResponse.Builder}
*/
public OAuth2AccessTokenResponse.Builder getAccessTokenResponse() {
return get(OAuth2AccessTokenResponse.Builder.class);
OAuth2AccessTokenResponse.Builder accessTokenResponse = get(OAuth2AccessTokenResponse.Builder.class);
Assert.notNull(accessTokenResponse, "accessTokenResponse cannot be null");
return accessTokenResponse;
}
/**

View File

@ -20,7 +20,8 @@ import java.io.Serial;
import java.util.Collections;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -52,7 +53,7 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
private final OAuth2AccessToken accessToken;
private final OAuth2RefreshToken refreshToken;
private final @Nullable OAuth2RefreshToken refreshToken;
private final Map<String, Object> additionalParameters;
@ -135,8 +136,7 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
* Returns the {@link OAuth2RefreshToken refresh token}.
* @return the {@link OAuth2RefreshToken} or {@code null} if not available
*/
@Nullable
public OAuth2RefreshToken getRefreshToken() {
public @Nullable OAuth2RefreshToken getRefreshToken() {
return this.refreshToken;
}

View File

@ -20,6 +20,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.context.Context;
@ -42,7 +44,9 @@ public interface OAuth2AuthenticationContext extends Context {
*/
@SuppressWarnings("unchecked")
default <T extends Authentication> T getAuthentication() {
return (T) get(Authentication.class);
Authentication authentication = get(Authentication.class);
Assert.notNull(authentication, "authentication cannot be null");
return (T) authentication;
}
/**
@ -85,7 +89,7 @@ public interface OAuth2AuthenticationContext extends Context {
}
@SuppressWarnings("unchecked")
protected <V> V get(Object key) {
protected <V> @Nullable V get(Object key) {
return (V) getContext().get(key);
}

View File

@ -43,8 +43,9 @@ final class OAuth2AuthenticationProviderUtils {
static OAuth2ClientAuthenticationToken getAuthenticatedClientElseThrowInvalidClient(Authentication authentication) {
OAuth2ClientAuthenticationToken clientPrincipal = null;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication.getPrincipal().getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) authentication.getPrincipal();
Object principal = authentication.getPrincipal();
if (principal != null && OAuth2ClientAuthenticationToken.class.isAssignableFrom(principal.getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) principal;
}
if (clientPrincipal != null && clientPrincipal.isAuthenticated()) {
return clientPrincipal;

View File

@ -30,6 +30,7 @@ import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationProvider;
@ -96,7 +97,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private SessionRegistry sessionRegistry;
private @Nullable SessionRegistry sessionRegistry;
/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the
@ -119,6 +120,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(authorizationCodeAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");
@ -136,9 +138,11 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization
.getToken(OAuth2AuthorizationCode.class);
Assert.notNull(authorizationCode, "authorizationCode cannot be null");
OAuth2AuthorizationRequest authorizationRequest = authorization
.getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
if (!authorizationCode.isInvalidated()) {
@ -193,6 +197,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
}
Authentication principal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(principal, "principal cannot be null");
// @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
@ -331,10 +336,14 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
this.sessionRegistry = sessionRegistry;
}
private SessionInformation getSessionInformation(Authentication principal) {
private @Nullable SessionInformation getSessionInformation(Authentication principal) {
SessionInformation sessionInformation = null;
if (this.sessionRegistry != null) {
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), false);
Object sessionPrincipal = principal.getPrincipal();
if (sessionPrincipal == null) {
return null;
}
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(sessionPrincipal, false);
if (!CollectionUtils.isEmpty(sessions)) {
sessionInformation = sessions.get(0);
if (sessions.size() > 1) {

View File

@ -18,7 +18,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
@ -38,7 +39,7 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends OAuth2Authorizat
private final String code;
private final String redirectUri;
private final @Nullable String redirectUri;
/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided
@ -68,8 +69,7 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends OAuth2Authorizat
* Returns the redirect uri.
* @return the redirect uri
*/
@Nullable
public String getRedirectUri() {
public @Nullable String getRedirectUri() {
return this.redirectUri;
}

View File

@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.time.Instant;
import java.util.Base64;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -42,10 +43,9 @@ final class OAuth2AuthorizationCodeGenerator implements OAuth2TokenGenerator<OAu
private final StringKeyGenerator authorizationCodeGenerator = new Base64StringKeyGenerator(
Base64.getUrlEncoder().withoutPadding(), 96);
@Nullable
@Override
public OAuth2AuthorizationCode generate(OAuth2TokenContext context) {
if (context.getTokenType() == null || !OAuth2ParameterNames.CODE.equals(context.getTokenType().getValue())) {
public @Nullable OAuth2AuthorizationCode generate(OAuth2TokenContext context) {
if (!OAuth2ParameterNames.CODE.equals(context.getTokenType().getValue())) {
return null;
}
Instant issuedAt = Instant.now();

View File

@ -22,7 +22,8 @@ import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@ -50,9 +51,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -67,15 +67,16 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
}
/**
* Returns the {@link OAuth2AuthorizationRequest authorization request}.
* @return the {@link OAuth2AuthorizationRequest}
*/
@Nullable
public OAuth2AuthorizationRequest getAuthorizationRequest() {
public @Nullable OAuth2AuthorizationRequest getAuthorizationRequest() {
return get(OAuth2AuthorizationRequest.class);
}
@ -83,8 +84,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
* Returns the {@link OAuth2AuthorizationConsent authorization consent}.
* @return the {@link OAuth2AuthorizationConsent}
*/
@Nullable
public OAuth2AuthorizationConsent getAuthorizationConsent() {
public @Nullable OAuth2AuthorizationConsent getAuthorizationConsent() {
return get(OAuth2AuthorizationConsent.class);
}

View File

@ -16,7 +16,8 @@
package org.springframework.security.oauth2.server.authorization.authentication;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@ -33,7 +34,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
*/
public class OAuth2AuthorizationCodeRequestAuthenticationException extends OAuth2AuthenticationException {
private final OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication;
private final @Nullable OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication;
/**
* Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationException} using
@ -67,8 +68,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationException extends OAuth
* (or Consent), or {@code null} if not available.
* @return the {@link OAuth2AuthorizationCodeRequestAuthenticationToken}
*/
@Nullable
public OAuth2AuthorizationCodeRequestAuthenticationToken getAuthorizationCodeRequestAuthentication() {
public @Nullable OAuth2AuthorizationCodeRequestAuthenticationToken getAuthorizationCodeRequestAuthentication() {
return this.authorizationCodeRequestAuthentication;
}

View File

@ -30,6 +30,7 @@ import java.util.function.Predicate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
@ -129,19 +130,19 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
.get(OAuth2ParameterNames.REQUEST_URI);
if (StringUtils.hasText(requestUri)) {
OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = null;
OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri;
try {
pushedAuthorizationRequestUri = OAuth2PushedAuthorizationRequestUri.parse(requestUri);
}
catch (Exception ex) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
authorizationCodeRequestAuthentication, null);
}
pushedAuthorization = this.authorizationService.findByToken(pushedAuthorizationRequestUri.getState(),
STATE_TOKEN_TYPE);
if (pushedAuthorization == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
authorizationCodeRequestAuthentication, null);
}
@ -151,9 +152,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
OAuth2AuthorizationRequest authorizationRequest = pushedAuthorization
.getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
if (!authorizationCodeRequestAuthentication.getClientId().equals(authorizationRequest.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationCodeRequestAuthentication, null);
}
@ -165,7 +167,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
.warn(LogMessage.format("Removed expired pushed authorization request for client id '%s'",
authorizationRequest.getClientId()));
}
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
authorizationCodeRequestAuthentication, null);
}
@ -179,7 +181,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(authorizationCodeRequestAuthentication.getClientId());
if (registeredClient == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationCodeRequestAuthentication, null);
}
@ -233,11 +235,12 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
if (!isPrincipalAuthenticated(principal)) {
if (promptValues.contains(OidcPrompt.NONE)) {
throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
throw createException("login_required", "prompt", authorizationCodeRequestAuthentication,
registeredClient);
}
else {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "principal", authorizationCodeRequestAuthentication,
registeredClient);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, "principal",
authorizationCodeRequestAuthentication, registeredClient);
}
}
@ -260,7 +263,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) {
if (promptValues.contains(OidcPrompt.NONE)) {
// Return an error instead of displaying the consent page
throwError("consent_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
throw createException("consent_required", "prompt", authorizationCodeRequestAuthentication,
registeredClient);
}
String state = DEFAULT_STATE_GENERATOR.generateKey();
@ -416,15 +420,17 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) {
return false;
}
OAuth2AuthorizationRequest authorizationRequest = authenticationContext.getAuthorizationRequest();
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
// 'openid' scope does not require consent
if (authenticationContext.getAuthorizationRequest().getScopes().contains(OidcScopes.OPENID)
&& authenticationContext.getAuthorizationRequest().getScopes().size() == 1) {
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)
&& authorizationRequest.getScopes().size() == 1) {
return false;
}
if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent()
.getScopes()
.containsAll(authenticationContext.getAuthorizationRequest().getScopes())) {
.containsAll(authorizationRequest.getScopes())) {
return false;
}
@ -442,7 +448,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
private static OAuth2TokenContext createAuthorizationCodeTokenContext(
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient, OAuth2Authorization authorization, Set<String> authorizedScopes) {
RegisteredClient registeredClient, @Nullable OAuth2Authorization authorization,
Set<String> authorizedScopes) {
// @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
@ -467,23 +474,27 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
&& principal.isAuthenticated();
}
private static void throwError(String errorCode, String parameterName,
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) {
throwError(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication, registeredClient, null);
@Nullable RegisteredClient registeredClient) {
return createException(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication,
registeredClient, null);
}
private static void throwError(String errorCode, String parameterName, String errorUri,
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName, String errorUri,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) {
@Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
throwError(error, parameterName, authorizationCodeRequestAuthentication, registeredClient,
return createException(error, parameterName, authorizationCodeRequestAuthentication, registeredClient,
authorizationRequest);
}
private static void throwError(OAuth2Error error, String parameterName,
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(OAuth2Error error,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) {
@Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
String redirectUri = resolveRedirectUri(authorizationCodeRequestAuthentication, authorizationRequest,
registeredClient);
@ -500,13 +511,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
authorizationCodeRequestAuthentication.getState(), authorizationCodeRequestAuthentication.getScopes(),
authorizationCodeRequestAuthentication.getAdditionalParameters());
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error,
return new OAuth2AuthorizationCodeRequestAuthenticationException(error,
authorizationCodeRequestAuthenticationResult);
}
private static String resolveRedirectUri(
private static @Nullable String resolveRedirectUri(
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
OAuth2AuthorizationRequest authorizationRequest, RegisteredClient registeredClient) {
@Nullable OAuth2AuthorizationRequest authorizationRequest, @Nullable RegisteredClient registeredClient) {
if (authorizationCodeRequestAuthentication != null
&& StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())) {

View File

@ -20,7 +20,8 @@ import java.io.Serial;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.util.Assert;
@ -40,7 +41,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
@Serial
private static final long serialVersionUID = -1946164725241393094L;
private final OAuth2AuthorizationCode authorizationCode;
private final @Nullable OAuth2AuthorizationCode authorizationCode;
private boolean validated;
@ -86,8 +87,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
* Returns the {@link OAuth2AuthorizationCode}.
* @return the {@link OAuth2AuthorizationCode}
*/
@Nullable
public OAuth2AuthorizationCode getAuthorizationCode() {
public @Nullable OAuth2AuthorizationCode getAuthorizationCode() {
return this.authorizationCode;
}

View File

@ -23,6 +23,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication;
@ -104,7 +105,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
"Invalid request: requested grant_type is not allowed for registered client '%s'",
registeredClient.getId()));
}
throwError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID,
throw createException(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID,
authorizationCodeRequestAuthentication, registeredClient);
}
}
@ -130,7 +131,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
LOGGER.debug(LogMessage.format("Invalid request: redirect_uri is missing or contains a fragment"
+ " for registered client '%s'", registeredClient.getId()));
}
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient);
}
@ -140,7 +141,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
// When comparing client redirect URIs against pre-registered URIs,
// authorization servers MUST utilize exact string matching.
if (!registeredClient.getRedirectUris().contains(requestedRedirectUri)) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient);
}
}
@ -166,7 +167,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
"Invalid request: redirect_uri does not match for registered client '%s'",
registeredClient.getId()));
}
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient);
}
}
@ -178,7 +179,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)
|| registeredClient.getRedirectUris().size() != 1) {
// redirect_uri is REQUIRED for OpenID Connect
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient);
}
}
@ -197,7 +198,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
LogMessage.format("Invalid request: requested scope is not allowed for registered client '%s'",
registeredClient.getId()));
}
throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE,
throw createException(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE,
authorizationCodeRequestAuthentication, registeredClient);
}
}
@ -215,12 +216,12 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
String codeChallengeMethod = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (!StringUtils.hasText(codeChallengeMethod) || !"S256".equals(codeChallengeMethod)) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI,
authorizationCodeRequestAuthentication, registeredClient);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD,
PKCE_ERROR_URI, authorizationCodeRequestAuthentication, registeredClient);
}
}
else if (registeredClient.getClientSettings().isRequireProofKey()) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI,
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI,
authorizationCodeRequestAuthentication, registeredClient);
}
}
@ -239,15 +240,15 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
if (promptValues.contains(OidcPrompt.NONE)) {
if (promptValues.contains(OidcPrompt.LOGIN) || promptValues.contains(OidcPrompt.CONSENT)
|| promptValues.contains(OidcPrompt.SELECT_ACCOUNT)) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "prompt", authorizationCodeRequestAuthentication,
registeredClient);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, "prompt",
authorizationCodeRequestAuthentication, registeredClient);
}
}
}
}
}
private static boolean isLoopbackAddress(String host) {
private static boolean isLoopbackAddress(@Nullable String host) {
if (!StringUtils.hasText(host)) {
return false;
}
@ -273,20 +274,24 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
}
}
private static void throwError(String errorCode, String parameterName,
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) {
throwError(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication, registeredClient);
return createException(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication,
registeredClient);
}
private static void throwError(String errorCode, String parameterName, String errorUri,
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName, String errorUri,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
throwError(error, parameterName, authorizationCodeRequestAuthentication, registeredClient);
return createException(error, parameterName, authorizationCodeRequestAuthentication, registeredClient);
}
private static void throwError(OAuth2Error error, String parameterName,
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(OAuth2Error error,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) {
@ -306,7 +311,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
authorizationCodeRequestAuthentication.getAdditionalParameters());
authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error,
return new OAuth2AuthorizationCodeRequestAuthenticationException(error,
authorizationCodeRequestAuthenticationResult);
}

View File

@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
@ -50,9 +51,8 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -68,7 +68,9 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
* @return the {@link OAuth2AuthorizationConsent.Builder}
*/
public OAuth2AuthorizationConsent.Builder getAuthorizationConsent() {
return get(OAuth2AuthorizationConsent.Builder.class);
OAuth2AuthorizationConsent.Builder authorizationConsentBuilder = get(OAuth2AuthorizationConsent.Builder.class);
Assert.notNull(authorizationConsentBuilder, "authorizationConsentBuilder cannot be null");
return authorizationConsentBuilder;
}
/**
@ -76,7 +78,9 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
}
/**
@ -84,14 +88,16 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
* @return the {@link OAuth2Authorization}
*/
public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class);
OAuth2Authorization authorization = get(OAuth2Authorization.class);
Assert.notNull(authorization, "authorization cannot be null");
return authorization;
}
/**
* Returns the {@link OAuth2AuthorizationRequest authorization request}.
* @return the {@link OAuth2AuthorizationRequest}
*/
public OAuth2AuthorizationRequest getAuthorizationRequest() {
public @Nullable OAuth2AuthorizationRequest getAuthorizationRequest() {
return get(OAuth2AuthorizationRequest.class);
}

View File

@ -23,6 +23,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider;
@ -79,7 +80,7 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator();
private Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer;
private @Nullable Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer;
/**
* Constructs an {@code OAuth2AuthorizationConsentAuthenticationProvider} using the
@ -100,7 +101,7 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
if (authentication instanceof OAuth2DeviceAuthorizationConsentAuthenticationToken) {
// This is NOT an OAuth 2.0 Authorization Consent for the Authorization Code
// Grant,
@ -114,8 +115,8 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
OAuth2Authorization authorization = this.authorizationService
.findByToken(authorizationConsentAuthentication.getState(), STATE_TOKEN_TYPE);
if (authorization == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE, authorizationConsentAuthentication,
null, null);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE,
authorizationConsentAuthentication, null, null);
}
if (this.logger.isTraceEnabled()) {
@ -125,14 +126,18 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
// The 'in-flight' authorization must be associated to the current principal
Authentication principal = (Authentication) authorizationConsentAuthentication.getPrincipal();
if (!isPrincipalAuthenticated(principal) || !principal.getName().equals(authorization.getPrincipalName())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE, authorizationConsentAuthentication,
null, null);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE,
authorizationConsentAuthentication, null, null);
}
RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(authorizationConsentAuthentication.getClientId());
if (registeredClient == null || !registeredClient.getId().equals(authorization.getRegisteredClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
if (registeredClient == null) {
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationConsentAuthentication, null, null);
}
if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) {
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationConsentAuthentication, registeredClient, null);
}
@ -142,11 +147,12 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
OAuth2AuthorizationRequest authorizationRequest = authorization
.getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
Set<String> requestedScopes = authorizationRequest.getScopes();
Set<String> authorizedScopes = new HashSet<>(authorizationConsentAuthentication.getScopes());
if (!requestedScopes.containsAll(authorizedScopes)) {
throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE, authorizationConsentAuthentication,
registeredClient, authorizationRequest);
throw createException(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE,
authorizationConsentAuthentication, registeredClient, authorizationRequest);
}
if (this.logger.isTraceEnabled()) {
@ -215,12 +221,12 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
if (this.logger.isTraceEnabled()) {
this.logger.trace("Removed authorization");
}
throwError(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID,
throw createException(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID,
authorizationConsentAuthentication, registeredClient, authorizationRequest);
}
OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build();
if (!authorizationConsent.equals(currentAuthorizationConsent)) {
if (currentAuthorizationConsent == null || !authorizationConsent.equals(currentAuthorizationConsent)) {
this.authorizationConsentService.save(authorizationConsent);
if (this.logger.isTraceEnabled()) {
this.logger.trace("Saved authorization consent");
@ -334,16 +340,17 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
&& principal.isAuthenticated();
}
private static void throwError(String errorCode, String parameterName,
OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) {
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName, OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication,
@Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, ERROR_URI);
throwError(error, parameterName, authorizationConsentAuthentication, registeredClient, authorizationRequest);
return createException(error, parameterName, authorizationConsentAuthentication, registeredClient,
authorizationRequest);
}
private static void throwError(OAuth2Error error, String parameterName,
OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) {
private static OAuth2AuthorizationCodeRequestAuthenticationException createException(OAuth2Error error,
String parameterName, OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication,
@Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
String redirectUri = resolveRedirectUri(authorizationRequest, registeredClient);
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
@ -363,12 +370,12 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
(Authentication) authorizationConsentAuthentication.getPrincipal(), redirectUri, state, requestedScopes,
null);
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error,
return new OAuth2AuthorizationCodeRequestAuthenticationException(error,
authorizationCodeRequestAuthenticationResult);
}
private static String resolveRedirectUri(OAuth2AuthorizationRequest authorizationRequest,
RegisteredClient registeredClient) {
private static @Nullable String resolveRedirectUri(@Nullable OAuth2AuthorizationRequest authorizationRequest,
@Nullable RegisteredClient registeredClient) {
if (authorizationRequest != null && StringUtils.hasText(authorizationRequest.getRedirectUri())) {
return authorizationRequest.getRedirectUri();
}

View File

@ -23,7 +23,8 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.util.Assert;

View File

@ -21,7 +21,8 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;

View File

@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
@ -45,9 +46,8 @@ public final class OAuth2ClientAuthenticationContext implements OAuth2Authentica
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -62,7 +62,9 @@ public final class OAuth2ClientAuthenticationContext implements OAuth2Authentica
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
}
/**

View File

@ -20,7 +20,8 @@ import java.io.Serial;
import java.util.Collections;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.Transient;
@ -49,11 +50,11 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
private final String clientId;
private final RegisteredClient registeredClient;
private final @Nullable RegisteredClient registeredClient;
private final ClientAuthenticationMethod clientAuthenticationMethod;
private final Object credentials;
private final @Nullable Object credentials;
private final Map<String, Object> additionalParameters;
@ -103,9 +104,8 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
return this.clientId;
}
@Nullable
@Override
public Object getCredentials() {
public @Nullable Object getCredentials() {
return this.credentials;
}
@ -115,8 +115,7 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
* @return the authenticated {@link RegisteredClient}, or {@code null} if not
* authenticated
*/
@Nullable
public RegisteredClient getRegisteredClient() {
public @Nullable RegisteredClient getRegisteredClient() {
return this.registeredClient;
}

View File

@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
@ -45,9 +46,8 @@ public final class OAuth2ClientCredentialsAuthenticationContext implements OAuth
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -62,7 +62,9 @@ public final class OAuth2ClientCredentialsAuthenticationContext implements OAuth
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
}
/**

View File

@ -95,6 +95,7 @@ public final class OAuth2ClientCredentialsAuthenticationProvider implements Auth
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(clientCredentialsAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");

View File

@ -21,7 +21,8 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;

View File

@ -21,10 +21,12 @@ import java.net.URISyntaxException;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider;
@ -138,6 +140,7 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
}
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "accessToken cannot be null");
if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
}
@ -199,9 +202,10 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
private OAuth2ClientRegistrationAuthenticationToken registerClient(
OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthentication,
OAuth2Authorization authorization) {
@Nullable OAuth2Authorization authorization) {
if (!isValidRedirectUris(clientRegistrationAuthentication.getClientRegistration().getRedirectUris())) {
List<String> redirectUris = clientRegistrationAuthentication.getClientRegistration().getRedirectUris();
if (!isValidRedirectUris((redirectUris != null) ? redirectUris : Collections.emptyList())) {
throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI,
OAuth2ClientMetadataClaimNames.REDIRECT_URIS);
}
@ -236,8 +240,10 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
if (authorization != null) {
// Invalidate the "initial" access token as it can only be used once
OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getAccessToken();
Assert.notNull(accessToken, "accessToken cannot be null");
OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization)
.invalidate(authorization.getAccessToken().getToken());
.invalidate(accessToken.getToken());
if (authorization.getRefreshToken() != null) {
builder.invalidate(authorization.getRefreshToken().getToken());
}
@ -265,8 +271,9 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken,
Set<String> requiredScope) {
Collection<String> authorizedScope = Collections.emptySet();
if (authorizedAccessToken.getClaims().containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE);
Map<String, Object> claims = authorizedAccessToken.getClaims();
if (claims != null && claims.containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) claims.get(OAuth2ParameterNames.SCOPE);
}
if (!authorizedScope.containsAll(requiredScope)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);

View File

@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.io.Serial;
import java.util.Collections;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration;
@ -40,8 +41,7 @@ public class OAuth2ClientRegistrationAuthenticationToken extends AbstractAuthent
@Serial
private static final long serialVersionUID = 7135429161909989115L;
@Nullable
private final Authentication principal;
private final @Nullable Authentication principal;
private final OAuth2ClientRegistration clientRegistration;
@ -62,9 +62,8 @@ public class OAuth2ClientRegistrationAuthenticationToken extends AbstractAuthent
}
}
@Nullable
@Override
public Object getPrincipal() {
public @Nullable Object getPrincipal() {
return this.principal;
}

View File

@ -23,6 +23,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider;
@ -72,7 +73,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
private final OAuth2AuthorizationConsentService authorizationConsentService;
private Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer;
private @Nullable Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer;
/**
* Constructs an {@code OAuth2DeviceAuthorizationConsentAuthenticationProvider} using
@ -99,7 +100,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
OAuth2Authorization authorization = this.authorizationService
.findByToken(deviceAuthorizationConsentAuthentication.getState(), STATE_TOKEN_TYPE);
if (authorization == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
}
if (this.logger.isTraceEnabled()) {
@ -109,13 +110,13 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
// The authorization must be associated to the current principal
Authentication principal = (Authentication) deviceAuthorizationConsentAuthentication.getPrincipal();
if (!isPrincipalAuthenticated(principal) || !principal.getName().equals(authorization.getPrincipalName())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
}
RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(deviceAuthorizationConsentAuthentication.getClientId());
if (registeredClient == null || !registeredClient.getId().equals(authorization.getRegisteredClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
}
if (this.logger.isTraceEnabled()) {
@ -123,9 +124,10 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
}
Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
Assert.notNull(requestedScopes, "requestedScopes cannot be null");
Set<String> authorizedScopes = new HashSet<>(deviceAuthorizationConsentAuthentication.getScopes());
if (!requestedScopes.containsAll(authorizedScopes)) {
throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE);
throw createException(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE);
}
if (this.logger.isTraceEnabled()) {
@ -177,7 +179,9 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
authorizationConsentBuilder.authorities(authorities::addAll);
OAuth2Authorization.Token<OAuth2DeviceCode> deviceCodeToken = authorization.getToken(OAuth2DeviceCode.class);
Assert.notNull(deviceCodeToken, "deviceCode cannot be null");
OAuth2Authorization.Token<OAuth2UserCode> userCodeToken = authorization.getToken(OAuth2UserCode.class);
Assert.notNull(userCodeToken, "userCode cannot be null");
if (authorities.isEmpty()) {
// Authorization consent denied (or revoked)
@ -196,11 +200,11 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
if (this.logger.isTraceEnabled()) {
this.logger.trace("Invalidated device code and user code because authorization consent was denied");
}
throwError(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID);
throw createException(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID);
}
OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build();
if (!authorizationConsent.equals(currentAuthorizationConsent)) {
if (currentAuthorizationConsent == null || !authorizationConsent.equals(currentAuthorizationConsent)) {
this.authorizationConsentService.save(authorizationConsent);
if (this.logger.isTraceEnabled()) {
this.logger.trace("Saved authorization consent");
@ -263,9 +267,9 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
&& principal.isAuthenticated();
}
private static void throwError(String errorCode, String parameterName) {
private static OAuth2AuthenticationException createException(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error);
return new OAuth2AuthenticationException(error);
}
}

View File

@ -22,7 +22,8 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.util.Assert;
@ -43,7 +44,7 @@ public class OAuth2DeviceAuthorizationConsentAuthenticationToken extends OAuth2A
private final String userCode;
private final Set<String> requestedScopes;
private final @Nullable Set<String> requestedScopes;
/**
* Constructs an {@code OAuth2DeviceAuthorizationConsentAuthenticationToken} using the
@ -98,9 +99,9 @@ public class OAuth2DeviceAuthorizationConsentAuthenticationToken extends OAuth2A
/**
* Returns the requested scopes.
* @return the requested scopes
* @return the requested scopes, or {@code null} if not available
*/
public Set<String> getRequestedScopes() {
public @Nullable Set<String> getRequestedScopes() {
return this.requestedScopes;
}

View File

@ -23,9 +23,9 @@ import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage;
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
@ -101,6 +101,7 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationProvider implem
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(deviceAuthorizationRequestAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");
@ -224,9 +225,8 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationProvider implem
private final StringKeyGenerator deviceCodeGenerator = new Base64StringKeyGenerator(
Base64.getUrlEncoder().withoutPadding(), 96);
@Nullable
@Override
public OAuth2DeviceCode generate(OAuth2TokenContext context) {
public @Nullable OAuth2DeviceCode generate(OAuth2TokenContext context) {
if (context.getTokenType() == null
|| !OAuth2ParameterNames.DEVICE_CODE.equals(context.getTokenType().getValue())) {
return null;
@ -268,9 +268,8 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationProvider implem
private final StringKeyGenerator userCodeGenerator = new UserCodeStringKeyGenerator();
@Nullable
@Override
public OAuth2UserCode generate(OAuth2TokenContext context) {
public @Nullable OAuth2UserCode generate(OAuth2TokenContext context) {
if (context.getTokenType() == null
|| !OAuth2ParameterNames.USER_CODE.equals(context.getTokenType().getValue())) {
return null;

View File

@ -23,7 +23,8 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2DeviceCode;
@ -47,13 +48,13 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
private final Authentication clientPrincipal;
private final String authorizationUri;
private final @Nullable String authorizationUri;
private final Set<String> scopes;
private final OAuth2DeviceCode deviceCode;
private final @Nullable OAuth2DeviceCode deviceCode;
private final OAuth2UserCode userCode;
private final @Nullable OAuth2UserCode userCode;
private final Map<String, Object> additionalParameters;
@ -116,7 +117,7 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
* Returns the authorization {@code URI}.
* @return the authorization {@code URI}
*/
public String getAuthorizationUri() {
public @Nullable String getAuthorizationUri() {
return this.authorizationUri;
}
@ -132,7 +133,7 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
* Returns the device code.
* @return the device code
*/
public OAuth2DeviceCode getDeviceCode() {
public @Nullable OAuth2DeviceCode getDeviceCode() {
return this.deviceCode;
}
@ -140,7 +141,7 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
* Returns the user code.
* @return the user code
*/
public OAuth2UserCode getUserCode() {
public @Nullable OAuth2UserCode getUserCode() {
return this.userCode;
}

View File

@ -104,6 +104,7 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(deviceCodeAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");
@ -119,8 +120,8 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
this.logger.trace("Retrieved authorization with device code");
}
OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode = authorization.getToken(OAuth2DeviceCode.class);
Assert.notNull(deviceCode, "deviceCode cannot be null");
if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) {
if (!deviceCode.isInvalidated()) {
@ -158,6 +159,9 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
throw new OAuth2AuthenticationException(error);
}
OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
Assert.notNull(userCode, "userCode cannot be null");
// authorization_pending
// The authorization request is still pending as the end user hasn't
// yet completed the user-interaction steps (Section 3.3). The
@ -193,10 +197,13 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
this.logger.trace("Validated device token request parameters");
}
Authentication principal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(principal, "principal cannot be null");
// @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.registeredClient(registeredClient)
.principal(authorization.getAttribute(Principal.class.getName()))
.principal(principal)
.authorizationServerContext(AuthorizationServerContextHolder.getContext())
.authorization(authorization)
.authorizedScopes(authorization.getAuthorizedScopes())

View File

@ -18,7 +18,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;

View File

@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
@ -48,9 +49,8 @@ public final class OAuth2DeviceVerificationAuthenticationContext implements OAut
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -65,7 +65,9 @@ public final class OAuth2DeviceVerificationAuthenticationContext implements OAut
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
}
/**
@ -73,15 +75,16 @@ public final class OAuth2DeviceVerificationAuthenticationContext implements OAut
* @return the {@link OAuth2Authorization}
*/
public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class);
OAuth2Authorization authorization = get(OAuth2Authorization.class);
Assert.notNull(authorization, "authorization cannot be null");
return authorization;
}
/**
* Returns the {@link OAuth2AuthorizationConsent authorization consent}.
* @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available
*/
@Nullable
public OAuth2AuthorizationConsent getAuthorizationConsent() {
public @Nullable OAuth2AuthorizationConsent getAuthorizationConsent() {
return get(OAuth2AuthorizationConsent.class);
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.security.Principal;
import java.util.Base64;
import java.util.Collections;
import java.util.Set;
import java.util.function.Predicate;
@ -115,6 +116,7 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
}
OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
Assert.notNull(userCode, "userCode cannot be null");
if (!userCode.isActive()) {
if (!userCode.isInvalidated()) {
authorization = OAuth2Authorization.from(authorization).invalidate(userCode.getToken()).build();
@ -137,12 +139,16 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
RegisteredClient registeredClient = this.registeredClientRepository
.findById(authorization.getRegisteredClientId());
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");
}
Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
if (requestedScopes == null) {
requestedScopes = Collections.emptySet();
}
OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext
.with(deviceVerificationAuthentication)
@ -174,7 +180,7 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
}
Set<String> currentAuthorizedScopes = (currentAuthorizationConsent != null)
? currentAuthorizationConsent.getScopes() : null;
? currentAuthorizationConsent.getScopes() : Collections.emptySet();
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerContextHolder.getContext()
.getAuthorizationServerSettings();

View File

@ -21,7 +21,8 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.util.Assert;
@ -46,7 +47,7 @@ public class OAuth2DeviceVerificationAuthenticationToken extends AbstractAuthent
private final Map<String, Object> additionalParameters;
private final String clientId;
private final @Nullable String clientId;
/**
* Constructs an {@code OAuth2DeviceVerificationAuthenticationToken} using the
@ -114,9 +115,9 @@ public class OAuth2DeviceVerificationAuthenticationToken extends AbstractAuthent
/**
* Returns the client identifier.
* @return the client identifier
* @return the client identifier, or {@code null} if not set
*/
public String getClientId() {
public @Nullable String getClientId() {
return this.clientId;
}

View File

@ -74,6 +74,7 @@ public final class OAuth2PushedAuthorizationRequestAuthenticationProvider implem
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(pushedAuthorizationRequestAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");

View File

@ -21,7 +21,8 @@ import java.time.Instant;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.util.Assert;
@ -39,9 +40,9 @@ public class OAuth2PushedAuthorizationRequestAuthenticationToken
@Serial
private static final long serialVersionUID = 7330534287786569644L;
private final String requestUri;
private final @Nullable String requestUri;
private final Instant requestUriExpiresAt;
private final @Nullable Instant requestUriExpiresAt;
/**
* Constructs an {@code OAuth2PushedAuthorizationRequestAuthenticationToken} using the
@ -91,8 +92,7 @@ public class OAuth2PushedAuthorizationRequestAuthenticationToken
* Returns the {@code request_uri} corresponding to the authorization request posted.
* @return the {@code request_uri} corresponding to the authorization request posted
*/
@Nullable
public String getRequestUri() {
public @Nullable String getRequestUri() {
return this.requestUri;
}
@ -102,8 +102,7 @@ public class OAuth2PushedAuthorizationRequestAuthenticationToken
* @return the expiration time on or after which the {@code request_uri} MUST NOT be
* accepted
*/
@Nullable
public Instant getRequestUriExpiresAt() {
public @Nullable Instant getRequestUriExpiresAt() {
return this.requestUriExpiresAt;
}

View File

@ -38,11 +38,11 @@ final class OAuth2PushedAuthorizationRequestUri {
private static final StringKeyGenerator DEFAULT_STATE_GENERATOR = new Base64StringKeyGenerator(
Base64.getUrlEncoder());
private String requestUri;
private final String requestUri;
private String state;
private final String state;
private Instant expiresAt;
private final Instant expiresAt;
static OAuth2PushedAuthorizationRequestUri create() {
return create(Instant.now().plusSeconds(300));
@ -50,23 +50,17 @@ final class OAuth2PushedAuthorizationRequestUri {
static OAuth2PushedAuthorizationRequestUri create(Instant expiresAt) {
String state = DEFAULT_STATE_GENERATOR.generateKey();
OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = new OAuth2PushedAuthorizationRequestUri();
pushedAuthorizationRequestUri.requestUri = REQUEST_URI_PREFIX + state + REQUEST_URI_DELIMITER
+ expiresAt.toEpochMilli();
pushedAuthorizationRequestUri.state = state + REQUEST_URI_DELIMITER + expiresAt.toEpochMilli();
pushedAuthorizationRequestUri.expiresAt = expiresAt;
return pushedAuthorizationRequestUri;
String requestUri = REQUEST_URI_PREFIX + state + REQUEST_URI_DELIMITER + expiresAt.toEpochMilli();
state = state + REQUEST_URI_DELIMITER + expiresAt.toEpochMilli();
return new OAuth2PushedAuthorizationRequestUri(requestUri, state, expiresAt);
}
static OAuth2PushedAuthorizationRequestUri parse(String requestUri) {
int stateStartIndex = REQUEST_URI_PREFIX.length();
int expiresAtStartIndex = requestUri.indexOf(REQUEST_URI_DELIMITER) + REQUEST_URI_DELIMITER.length();
OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = new OAuth2PushedAuthorizationRequestUri();
pushedAuthorizationRequestUri.requestUri = requestUri;
pushedAuthorizationRequestUri.state = requestUri.substring(stateStartIndex);
pushedAuthorizationRequestUri.expiresAt = Instant
.ofEpochMilli(Long.parseLong(requestUri.substring(expiresAtStartIndex)));
return pushedAuthorizationRequestUri;
String state = requestUri.substring(stateStartIndex);
Instant expiresAt = Instant.ofEpochMilli(Long.parseLong(requestUri.substring(expiresAtStartIndex)));
return new OAuth2PushedAuthorizationRequestUri(requestUri, state, expiresAt);
}
String getRequestUri() {
@ -81,7 +75,10 @@ final class OAuth2PushedAuthorizationRequestUri {
return this.expiresAt;
}
private OAuth2PushedAuthorizationRequestUri() {
private OAuth2PushedAuthorizationRequestUri(String requestUri, String state, Instant expiresAt) {
this.requestUri = requestUri;
this.state = state;
this.expiresAt = expiresAt;
}
}

View File

@ -105,6 +105,7 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(refreshTokenAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");
@ -137,6 +138,7 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
}
OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
Assert.notNull(refreshToken, "refreshToken cannot be null");
if (!refreshToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization code,
@ -168,7 +170,10 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
&& clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
// For public clients, verify the DPoP Proof public key is same as (current)
// access token public key binding
Map<String, Object> accessTokenClaims = authorization.getAccessToken().getClaims();
OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getAccessToken();
Assert.notNull(accessToken, "accessToken cannot be null");
Map<String, Object> accessTokenClaims = (accessToken.getClaims() != null) ? accessToken.getClaims()
: Collections.emptyMap();
verifyDPoPProofPublicKey(dPoPProof, () -> accessTokenClaims);
}
@ -180,10 +185,12 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
scopes = authorizedScopes;
}
Authentication principal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(principal, "principal cannot be null");
// @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.registeredClient(registeredClient)
.principal(authorization.getAttribute(Principal.class.getName()))
.principal(principal)
.authorizationServerContext(AuthorizationServerContextHolder.getContext())
.authorization(authorization)
.authorizedScopes(scopes)

View File

@ -21,7 +21,8 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;

View File

@ -48,11 +48,15 @@ public final class OAuth2TokenExchangeActor implements ClaimAccessor {
}
public String getIssuer() {
return getClaimAsString(OAuth2TokenClaimNames.ISS);
String issuer = getClaimAsString(OAuth2TokenClaimNames.ISS);
Assert.notNull(issuer, "issuer cannot be null");
return issuer;
}
public String getSubject() {
return getClaimAsString(OAuth2TokenClaimNames.SUB);
String subject = getClaimAsString(OAuth2TokenClaimNames.SUB);
Assert.notNull(subject, "subject cannot be null");
return subject;
}
@Override

View File

@ -28,6 +28,7 @@ import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
@ -106,6 +107,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(tokenExchangeAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");
@ -133,6 +135,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
OAuth2Authorization.Token<OAuth2Token> subjectToken = subjectAuthorization
.getToken(tokenExchangeAuthentication.getSubjectToken());
Assert.notNull(subjectToken, "subjectToken cannot be null");
if (!subjectToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization code,
@ -175,6 +178,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
OAuth2Authorization.Token<OAuth2Token> actorToken = actorAuthorization
.getToken(tokenExchangeAuthentication.getActorToken());
Assert.notNull(actorToken, "actorToken cannot be null");
if (!actorToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization
@ -184,8 +188,11 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
}
if (!isValidTokenType(tokenExchangeAuthentication.getActorTokenType(), actorToken)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
String actorTokenType = tokenExchangeAuthentication.getActorTokenType();
if (actorTokenType != null) {
if (!isValidTokenType(actorTokenType, actorToken)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
}
}
if (authorizedActorClaims != null) {
@ -288,7 +295,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
return new LinkedHashSet<>(requestedScopes);
}
private static void validateClaims(Map<String, Object> expectedClaims, Map<String, Object> actualClaims,
private static void validateClaims(Map<String, Object> expectedClaims, @Nullable Map<String, Object> actualClaims,
String... claimNames) {
if (actualClaims == null) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
@ -302,8 +309,9 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
}
private static Authentication getPrincipal(OAuth2Authorization subjectAuthorization,
OAuth2Authorization actorAuthorization) {
@Nullable OAuth2Authorization actorAuthorization) {
Authentication subjectPrincipal = subjectAuthorization.getAttribute(Principal.class.getName());
Assert.notNull(subjectPrincipal, "subject principal cannot be null");
if (actorAuthorization == null) {
if (subjectPrincipal instanceof OAuth2TokenExchangeCompositeAuthenticationToken compositeAuthenticationToken) {
return compositeAuthenticationToken.getSubject();
@ -312,8 +320,11 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
}
// Capture claims for current actor's access token
OAuth2TokenExchangeActor currentActor = new OAuth2TokenExchangeActor(
actorAuthorization.getAccessToken().getClaims());
OAuth2Authorization.Token<OAuth2AccessToken> actorAccessToken = actorAuthorization.getAccessToken();
Assert.notNull(actorAccessToken, "actor access token cannot be null");
Map<String, Object> actorAccessTokenClaims = actorAccessToken.getClaims();
Assert.notNull(actorAccessTokenClaims, "actor access token claims cannot be null");
OAuth2TokenExchangeActor currentActor = new OAuth2TokenExchangeActor(actorAccessTokenClaims);
List<OAuth2TokenExchangeActor> actorPrincipals = new LinkedList<>();
actorPrincipals.add(currentActor);

View File

@ -22,7 +22,8 @@ import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
@ -43,9 +44,9 @@ public class OAuth2TokenExchangeAuthenticationToken extends OAuth2AuthorizationG
private final String subjectTokenType;
private final String actorToken;
private final @Nullable String actorToken;
private final String actorTokenType;
private final @Nullable String actorTokenType;
private final Set<String> resources;
@ -113,17 +114,17 @@ public class OAuth2TokenExchangeAuthenticationToken extends OAuth2AuthorizationG
/**
* Returns the actor token.
* @return the actor token
* @return the actor token, or {@code null} if not provided
*/
public String getActorToken() {
public @Nullable String getActorToken() {
return this.actorToken;
}
/**
* Returns the actor token type.
* @return the actor token type
* @return the actor token type, or {@code null} if not provided
*/
public String getActorTokenType() {
public @Nullable String getActorTokenType() {
return this.actorTokenType;
}

View File

@ -21,6 +21,8 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.util.Assert;
@ -51,12 +53,12 @@ public class OAuth2TokenExchangeCompositeAuthenticationToken extends AbstractAut
}
@Override
public Object getPrincipal() {
public @Nullable Object getPrincipal() {
return this.subject.getPrincipal();
}
@Override
public Object getCredentials() {
public @Nullable Object getCredentials() {
return null;
}

View File

@ -102,6 +102,7 @@ public final class OAuth2TokenIntrospectionAuthenticationProvider implements Aut
OAuth2Authorization.Token<OAuth2Token> authorizedToken = authorization
.getToken(tokenIntrospectionAuthentication.getToken());
Assert.notNull(authorizedToken, "authorizedToken cannot be null");
if (!authorizedToken.isActive()) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Did not introspect token since not active");
@ -112,6 +113,7 @@ public final class OAuth2TokenIntrospectionAuthenticationProvider implements Aut
RegisteredClient authorizedClient = this.registeredClientRepository
.findById(authorization.getRegisteredClientId());
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
OAuth2TokenIntrospection tokenClaims = withActiveTokenClaims(authorizedToken, authorizedClient);
if (this.logger.isTraceEnabled()) {

View File

@ -21,7 +21,8 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenIntrospection;
@ -46,7 +47,7 @@ public class OAuth2TokenIntrospectionAuthenticationToken extends AbstractAuthent
private final Authentication clientPrincipal;
private final String tokenTypeHint;
private final @Nullable String tokenTypeHint;
private final Map<String, Object> additionalParameters;
@ -118,8 +119,7 @@ public class OAuth2TokenIntrospectionAuthenticationToken extends AbstractAuthent
* Returns the token type hint.
* @return the token type hint
*/
@Nullable
public String getTokenTypeHint() {
public @Nullable String getTokenTypeHint() {
return this.tokenTypeHint;
}

View File

@ -64,6 +64,7 @@ public final class OAuth2TokenRevocationAuthenticationProvider implements Authen
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(tokenRevocationAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
OAuth2Authorization authorization = this.authorizationService
.findByToken(tokenRevocationAuthentication.getToken(), null);
@ -80,6 +81,7 @@ public final class OAuth2TokenRevocationAuthenticationProvider implements Authen
}
OAuth2Authorization.Token<OAuth2Token> token = authorization.getToken(tokenRevocationAuthentication.getToken());
Assert.notNull(token, "token cannot be null");
authorization = OAuth2Authorization.from(authorization).invalidate(token.getToken()).build();
this.authorizationService.save(authorization);

View File

@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.io.Serial;
import java.util.Collections;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2Token;
@ -43,7 +44,7 @@ public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthentica
private final Authentication clientPrincipal;
private final String tokenTypeHint;
private final @Nullable String tokenTypeHint;
/**
* Constructs an {@code OAuth2TokenRevocationAuthenticationToken} using the provided
@ -100,8 +101,7 @@ public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthentica
* Returns the token type hint.
* @return the token type hint
*/
@Nullable
public String getTokenTypeHint() {
public @Nullable String getTokenTypeHint() {
return this.tokenTypeHint;
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
@ -70,7 +71,7 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
if (!ClientAuthenticationMethod.NONE.equals(clientAuthentication.getClientAuthenticationMethod())) {
@ -80,7 +81,7 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
throw invalidClient(OAuth2ParameterNames.CLIENT_ID);
}
if (this.logger.isTraceEnabled()) {
@ -89,7 +90,7 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
if (!registeredClient.getClientAuthenticationMethods()
.contains(clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method");
throw invalidClient("authentication_method");
}
if (this.logger.isTraceEnabled()) {
@ -112,10 +113,10 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
}
private static void throwInvalidClient(String parameterName) {
private static OAuth2AuthenticationException invalidClient(String parameterName) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error);
return new OAuth2AuthenticationException(error);
}
}

View File

@ -21,6 +21,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
@ -79,7 +80,7 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
if (!ClientAuthenticationMethod.TLS_CLIENT_AUTH.equals(clientAuthentication.getClientAuthenticationMethod())
@ -91,7 +92,7 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
throw invalidClient(OAuth2ParameterNames.CLIENT_ID);
}
if (this.logger.isTraceEnabled()) {
@ -100,11 +101,11 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
if (!registeredClient.getClientAuthenticationMethods()
.contains(clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method");
throw invalidClient("authentication_method");
}
if (!(clientAuthentication.getCredentials() instanceof X509Certificate[])) {
throwInvalidClient("credentials");
throw invalidClient("credentials");
}
OAuth2ClientAuthenticationContext authenticationContext = OAuth2ClientAuthenticationContext
@ -170,22 +171,23 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient();
X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials();
Assert.notEmpty(clientCertificateChain, "clientCertificateChain cannot be empty");
X509Certificate clientCertificate = clientCertificateChain[0];
String expectedSubjectDN = registeredClient.getClientSettings().getX509CertificateSubjectDN();
if (!StringUtils.hasText(expectedSubjectDN)
|| !clientCertificate.getSubjectX500Principal().getName().equals(expectedSubjectDN)) {
throwInvalidClient("x509_certificate_subject_dn");
throw invalidClient("x509_certificate_subject_dn");
}
}
private static void throwInvalidClient(String parameterName) {
throwInvalidClient(parameterName, null);
private static OAuth2AuthenticationException invalidClient(String parameterName) {
return invalidClient(parameterName, null);
}
private static void throwInvalidClient(String parameterName, Throwable cause) {
private static OAuth2AuthenticationException invalidClient(String parameterName, @Nullable Throwable cause) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error, error.toString(), cause);
return new OAuth2AuthenticationException(error, error.toString(), cause);
}
}

View File

@ -37,6 +37,7 @@ import javax.security.auth.x500.X500Principal;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSet;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@ -48,6 +49,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
@ -74,12 +76,13 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient();
X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials();
Assert.notEmpty(clientCertificateChain, "clientCertificateChain cannot be empty");
X509Certificate clientCertificate = clientCertificateChain[0];
X500Principal issuer = clientCertificate.getIssuerX500Principal();
X500Principal subject = clientCertificate.getSubjectX500Principal();
if (issuer == null || !issuer.equals(subject)) {
throwInvalidClient("x509_certificate_issuer");
throw invalidClient("x509_certificate_issuer");
}
JWKSet jwkSet = this.jwkSetSupplier.apply(registeredClient);
@ -95,18 +98,18 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
}
if (!publicKeyMatches) {
throwInvalidClient("x509_certificate");
throw invalidClient("x509_certificate");
}
}
private static void throwInvalidClient(String parameterName) {
throwInvalidClient(parameterName, null);
private static OAuth2AuthenticationException invalidClient(String parameterName) {
return invalidClient(parameterName, null);
}
private static void throwInvalidClient(String parameterName, Throwable cause) {
private static OAuth2AuthenticationException invalidClient(String parameterName, @Nullable Throwable cause) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error, error.toString(), cause);
return new OAuth2AuthenticationException(error, error.toString(), cause);
}
private static final class JwkSetSupplier implements Function<RegisteredClient, JWKSet> {
@ -128,7 +131,7 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
public JWKSet apply(RegisteredClient registeredClient) {
Supplier<JWKSet> jwkSetSupplier = this.jwkSets.computeIfAbsent(registeredClient.getId(), (key) -> {
if (!StringUtils.hasText(registeredClient.getClientSettings().getJwkSetUrl())) {
throwInvalidClient("client_jwk_set_url");
throw invalidClient("client_jwk_set_url");
}
return new JwkSetHolder(registeredClient.getClientSettings().getJwkSetUrl());
});
@ -136,34 +139,36 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
}
private JWKSet retrieve(String jwkSetUrl) {
URI jwkSetUri = null;
final URI jwkSetUri;
try {
jwkSetUri = new URI(jwkSetUrl);
}
catch (URISyntaxException ex) {
throwInvalidClient("jwk_set_uri", ex);
throw invalidClient("jwk_set_uri", ex);
}
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, jwkSetUri);
ResponseEntity<String> response = null;
final ResponseEntity<String> response;
try {
response = this.restOperations.exchange(request, String.class);
}
catch (Exception ex) {
throwInvalidClient("jwk_set_response_error", ex);
throw invalidClient("jwk_set_response_error", ex);
}
if (response.getStatusCode().value() != 200) {
throwInvalidClient("jwk_set_response_status");
throw invalidClient("jwk_set_response_status");
}
JWKSet jwkSet = null;
final JWKSet jwkSet;
try {
jwkSet = JWKSet.parse(response.getBody());
String body = response.getBody();
Assert.notNull(body, "response body cannot be null");
jwkSet = JWKSet.parse(body);
}
catch (ParseException ex) {
throwInvalidClient("jwk_set_response_body", ex);
throw invalidClient("jwk_set_response_body", ex);
}
return jwkSet;
@ -177,9 +182,9 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
private final String jwkSetUrl;
private JWKSet jwkSet;
private @Nullable JWKSet jwkSet;
private Instant lastUpdatedAt;
private @Nullable Instant lastUpdatedAt;
private JwkSetHolder(String jwkSetUrl) {
this.jwkSetUrl = jwkSetUrl;
@ -204,6 +209,7 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
}
try {
Assert.notNull(this.jwkSet, "jwkSet cannot be null");
return this.jwkSet;
}
finally {
@ -213,7 +219,7 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
private boolean shouldRefresh() {
// Refresh every 5 minutes
return (this.jwkSet == null
return (this.jwkSet == null || this.lastUpdatedAt == null
|| this.clock.instant().isAfter(this.lastUpdatedAt.plus(5, ChronoUnit.MINUTES)));
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* {@link org.springframework.security.authentication.AuthenticationProvider}
* implementations and related types for OAuth2 and OpenID Connect 1.0 flows handled by
* the authorization server.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.authentication;
import org.jspecify.annotations.NullMarked;

View File

@ -21,7 +21,8 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -83,16 +84,14 @@ public final class InMemoryRegisteredClientRepository implements RegisteredClien
this.clientIdRegistrationMap.put(registeredClient.getClientId(), registeredClient);
}
@Nullable
@Override
public RegisteredClient findById(String id) {
public @Nullable RegisteredClient findById(String id) {
Assert.hasText(id, "id cannot be empty");
return this.idRegistrationMap.get(id);
}
@Nullable
@Override
public RegisteredClient findByClientId(String clientId) {
public @Nullable RegisteredClient findByClientId(String clientId) {
Assert.hasText(clientId, "clientId cannot be empty");
return this.clientIdRegistrationMap.get(clientId);
}

View File

@ -31,6 +31,7 @@ import java.util.function.Function;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
import tools.jackson.databind.JacksonModule;
import tools.jackson.databind.json.JsonMapper;
@ -190,18 +191,18 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
}
@Override
public RegisteredClient findById(String id) {
public @Nullable RegisteredClient findById(String id) {
Assert.hasText(id, "id cannot be empty");
return findBy("id = ?", id);
}
@Override
public RegisteredClient findByClientId(String clientId) {
public @Nullable RegisteredClient findByClientId(String clientId) {
Assert.hasText(clientId, "clientId cannot be empty");
return findBy("client_id = ?", clientId);
}
private RegisteredClient findBy(String filter, Object... args) {
private @Nullable RegisteredClient findBy(String filter, Object... args) {
List<RegisteredClient> result = this.jdbcOperations.query(LOAD_REGISTERED_CLIENT_SQL + filter,
this.registeredClientRowMapper, args);
return !result.isEmpty() ? result.get(0) : null;
@ -334,10 +335,15 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
// @formatter:off
RegisteredClient.Builder builder = RegisteredClient.withId(rs.getString("id"))
.clientId(rs.getString("client_id"))
.clientIdIssuedAt((clientIdIssuedAt != null) ? clientIdIssuedAt.toInstant() : null)
.clientSecret(rs.getString("client_secret"))
.clientSecretExpiresAt((clientSecretExpiresAt != null) ? clientSecretExpiresAt.toInstant() : null)
.clientId(rs.getString("client_id"));
if (clientIdIssuedAt != null) {
builder.clientIdIssuedAt(clientIdIssuedAt.toInstant());
}
builder.clientSecret(rs.getString("client_secret"));
if (clientSecretExpiresAt != null) {
builder.clientSecretExpiresAt(clientSecretExpiresAt.toInstant());
}
builder
.clientName(rs.getString("client_name"))
.clientAuthenticationMethods((authenticationMethods) ->
clientAuthenticationMethods.forEach((authenticationMethod) ->
@ -558,7 +564,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
static class JdbcRegisteredClientRepositoryRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.resources()
.registerResource(new ClassPathResource(
"org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql"));

View File

@ -27,7 +27,8 @@ import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.server.authorization.settings.ClientSettings;
@ -50,31 +51,31 @@ public class RegisteredClient implements Serializable {
@Serial
private static final long serialVersionUID = -717282636175335081L;
private String id;
private @Nullable String id;
private String clientId;
private @Nullable String clientId;
private Instant clientIdIssuedAt;
private @Nullable Instant clientIdIssuedAt;
private String clientSecret;
private @Nullable String clientSecret;
private Instant clientSecretExpiresAt;
private @Nullable Instant clientSecretExpiresAt;
private String clientName;
private @Nullable String clientName;
private Set<ClientAuthenticationMethod> clientAuthenticationMethods;
private @Nullable Set<ClientAuthenticationMethod> clientAuthenticationMethods;
private Set<AuthorizationGrantType> authorizationGrantTypes;
private @Nullable Set<AuthorizationGrantType> authorizationGrantTypes;
private Set<String> redirectUris;
private @Nullable Set<String> redirectUris;
private Set<String> postLogoutRedirectUris;
private @Nullable Set<String> postLogoutRedirectUris;
private Set<String> scopes;
private @Nullable Set<String> scopes;
private ClientSettings clientSettings;
private @Nullable ClientSettings clientSettings;
private TokenSettings tokenSettings;
private @Nullable TokenSettings tokenSettings;
protected RegisteredClient() {
}
@ -84,6 +85,7 @@ public class RegisteredClient implements Serializable {
* @return the identifier for the registration
*/
public String getId() {
Assert.notNull(this.id, "id cannot be null");
return this.id;
}
@ -92,6 +94,7 @@ public class RegisteredClient implements Serializable {
* @return the client identifier
*/
public String getClientId() {
Assert.notNull(this.clientId, "clientId cannot be null");
return this.clientId;
}
@ -99,8 +102,7 @@ public class RegisteredClient implements Serializable {
* Returns the time at which the client identifier was issued.
* @return the time at which the client identifier was issued
*/
@Nullable
public Instant getClientIdIssuedAt() {
public @Nullable Instant getClientIdIssuedAt() {
return this.clientIdIssuedAt;
}
@ -108,8 +110,7 @@ public class RegisteredClient implements Serializable {
* Returns the client secret or {@code null} if not available.
* @return the client secret or {@code null} if not available
*/
@Nullable
public String getClientSecret() {
public @Nullable String getClientSecret() {
return this.clientSecret;
}
@ -119,8 +120,7 @@ public class RegisteredClient implements Serializable {
* @return the time at which the client secret expires or {@code null} if it does not
* expire
*/
@Nullable
public Instant getClientSecretExpiresAt() {
public @Nullable Instant getClientSecretExpiresAt() {
return this.clientSecretExpiresAt;
}
@ -129,6 +129,7 @@ public class RegisteredClient implements Serializable {
* @return the client name
*/
public String getClientName() {
Assert.notNull(this.clientName, "clientName cannot be null");
return this.clientName;
}
@ -139,6 +140,7 @@ public class RegisteredClient implements Serializable {
* method(s)}
*/
public Set<ClientAuthenticationMethod> getClientAuthenticationMethods() {
Assert.notNull(this.clientAuthenticationMethods, "clientAuthenticationMethods cannot be null");
return this.clientAuthenticationMethods;
}
@ -149,6 +151,7 @@ public class RegisteredClient implements Serializable {
* type(s)}
*/
public Set<AuthorizationGrantType> getAuthorizationGrantTypes() {
Assert.notNull(this.authorizationGrantTypes, "authorizationGrantTypes cannot be null");
return this.authorizationGrantTypes;
}
@ -157,6 +160,7 @@ public class RegisteredClient implements Serializable {
* @return the {@code Set} of redirect URI(s)
*/
public Set<String> getRedirectUris() {
Assert.notNull(this.redirectUris, "redirectUris cannot be null");
return this.redirectUris;
}
@ -167,6 +171,7 @@ public class RegisteredClient implements Serializable {
* @return the {@code Set} of post logout redirect URI(s)
*/
public Set<String> getPostLogoutRedirectUris() {
Assert.notNull(this.postLogoutRedirectUris, "postLogoutRedirectUris cannot be null");
return this.postLogoutRedirectUris;
}
@ -175,6 +180,7 @@ public class RegisteredClient implements Serializable {
* @return the {@code Set} of scope(s)
*/
public Set<String> getScopes() {
Assert.notNull(this.scopes, "scopes cannot be null");
return this.scopes;
}
@ -183,6 +189,7 @@ public class RegisteredClient implements Serializable {
* @return the {@link ClientSettings}
*/
public ClientSettings getClientSettings() {
Assert.notNull(this.clientSettings, "clientSettings cannot be null");
return this.clientSettings;
}
@ -191,6 +198,7 @@ public class RegisteredClient implements Serializable {
* @return the {@link TokenSettings}
*/
public TokenSettings getTokenSettings() {
Assert.notNull(this.tokenSettings, "tokenSettings cannot be null");
return this.tokenSettings;
}
@ -261,17 +269,17 @@ public class RegisteredClient implements Serializable {
*/
public static class Builder {
private String id;
private @Nullable String id;
private String clientId;
private @Nullable String clientId;
private Instant clientIdIssuedAt;
private @Nullable Instant clientIdIssuedAt;
private String clientSecret;
private @Nullable String clientSecret;
private Instant clientSecretExpiresAt;
private @Nullable Instant clientSecretExpiresAt;
private String clientName;
private @Nullable String clientName;
private final Set<ClientAuthenticationMethod> clientAuthenticationMethods = new HashSet<>();
@ -283,9 +291,9 @@ public class RegisteredClient implements Serializable {
private final Set<String> scopes = new HashSet<>();
private ClientSettings clientSettings;
private @Nullable ClientSettings clientSettings;
private TokenSettings tokenSettings;
private @Nullable TokenSettings tokenSettings;
protected Builder(String id) {
this.id = id;

View File

@ -16,7 +16,7 @@
package org.springframework.security.oauth2.server.authorization.client;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
/**
* A repository for OAuth 2.0 {@link RegisteredClient}(s).
@ -45,8 +45,7 @@ public interface RegisteredClientRepository {
* @param id the registration identifier
* @return the {@link RegisteredClient} if found, otherwise {@code null}
*/
@Nullable
RegisteredClient findById(String id);
@Nullable RegisteredClient findById(String id);
/**
* Returns the registered client identified by the provided {@code clientId}, or
@ -54,7 +53,6 @@ public interface RegisteredClientRepository {
* @param clientId the client identifier
* @return the {@link RegisteredClient} if found, otherwise {@code null}
*/
@Nullable
RegisteredClient findByClientId(String clientId);
@Nullable RegisteredClient findByClientId(String clientId);
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Client registration persistence for the authorization server, including
* {@link org.springframework.security.oauth2.server.authorization.client.RegisteredClient}
* and repository abstractions.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.client;
import org.jspecify.annotations.NullMarked;

View File

@ -16,7 +16,8 @@
package org.springframework.security.oauth2.server.authorization.context;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert;
/**
@ -34,8 +35,7 @@ public interface Context {
* @return the value of the attribute associated to the key, or {@code null} if not
* available
*/
@Nullable
<V> V get(Object key);
<V> @Nullable V get(Object key);
/**
* Returns the value of the attribute associated to the key.
@ -44,8 +44,7 @@ public interface Context {
* @return the value of the attribute associated to the key, or {@code null} if not
* available or not of the specified type
*/
@Nullable
default <V> V get(Class<V> key) {
default <V> @Nullable V get(Class<V> key) {
Assert.notNull(key, "key cannot be null");
V value = get((Object) key);
return key.isInstance(value) ? value : null;

View File

@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Context types that carry authorization server request state and attributes during
* protocol processing.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.context;
import org.jspecify.annotations.NullMarked;

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.converter;
import java.time.Instant;
import java.util.Base64;
import java.util.List;
import java.util.UUID;
import java.util.function.Consumer;
@ -58,9 +59,11 @@ public final class OAuth2ClientRegistrationRegisteredClientConverter
// @formatter:off
RegisteredClient.Builder builder = RegisteredClient.withId(UUID.randomUUID().toString())
.clientId(CLIENT_ID_GENERATOR.generateKey())
.clientIdIssuedAt(Instant.now())
.clientName(clientRegistration.getClientName());
.clientIdIssuedAt(Instant.now());
String clientName = clientRegistration.getClientName();
if (clientName != null) {
builder.clientName(clientName);
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) {
builder
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
@ -80,9 +83,10 @@ public final class OAuth2ClientRegistrationRegisteredClientConverter
redirectUris.addAll(clientRegistration.getRedirectUris()));
}
if (!CollectionUtils.isEmpty(clientRegistration.getGrantTypes())) {
List<String> grantTypes = clientRegistration.getGrantTypes();
if (!CollectionUtils.isEmpty(grantTypes)) {
builder.authorizationGrantTypes((authorizationGrantTypes) ->
clientRegistration.getGrantTypes().forEach((grantType) ->
grantTypes.forEach((grantType) ->
authorizationGrantTypes.add(new AuthorizationGrantType(grantType))));
}
else {

View File

@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.converter;
import java.time.Instant;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
@ -39,8 +41,11 @@ public final class RegisteredClientOAuth2ClientRegistrationConverter
// @formatter:off
OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder()
.clientId(registeredClient.getClientId())
.clientIdIssuedAt(registeredClient.getClientIdIssuedAt())
.clientName(registeredClient.getClientName());
Instant clientIdIssuedAt = registeredClient.getClientIdIssuedAt();
if (clientIdIssuedAt != null) {
builder.clientIdIssuedAt(clientIdIssuedAt);
}
builder
.tokenEndpointAuthenticationMethod(registeredClient.getClientAuthenticationMethods().iterator().next().getValue());

View File

@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* {@link org.springframework.core.convert.converter.Converter} implementations for
* authorization server domain types.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.converter;
import org.jspecify.annotations.NullMarked;

View File

@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.http.converter;
import org.jspecify.annotations.Nullable;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.json.GsonHttpMessageConverter;
@ -54,7 +56,7 @@ final class HttpMessageConverters {
}
@SuppressWarnings("removal")
static GenericHttpMessageConverter<Object> getJsonMessageConverter() {
static @Nullable GenericHttpMessageConverter<Object> getJsonMessageConverter() {
if (jacksonPresent) {
return new GenericHttpMessageConverterAdapter<>(new JacksonJsonHttpMessageConverter());
}

View File

@ -53,8 +53,7 @@ public class OAuth2AuthorizationServerMetadataHttpMessageConverter
private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() {
};
private final GenericHttpMessageConverter<Object> jsonMessageConverter = HttpMessageConverters
.getJsonMessageConverter();
private final GenericHttpMessageConverter<Object> jsonMessageConverter;
private Converter<Map<String, Object>, OAuth2AuthorizationServerMetadata> authorizationServerMetadataConverter = new OAuth2AuthorizationServerMetadataConverter();
@ -62,6 +61,9 @@ public class OAuth2AuthorizationServerMetadataHttpMessageConverter
public OAuth2AuthorizationServerMetadataHttpMessageConverter() {
super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json"));
GenericHttpMessageConverter<Object> converter = HttpMessageConverters.getJsonMessageConverter();
Assert.notNull(converter, "Unable to locate a supported JSON message converter");
this.jsonMessageConverter = converter;
}
@Override

View File

@ -26,6 +26,8 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.jspecify.annotations.Nullable;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.converter.Converter;
@ -60,8 +62,7 @@ public class OAuth2ClientRegistrationHttpMessageConverter
private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() {
};
private final GenericHttpMessageConverter<Object> jsonMessageConverter = HttpMessageConverters
.getJsonMessageConverter();
private final GenericHttpMessageConverter<Object> jsonMessageConverter;
private Converter<Map<String, Object>, OAuth2ClientRegistration> clientRegistrationConverter = new MapOAuth2ClientRegistrationConverter();
@ -69,6 +70,9 @@ public class OAuth2ClientRegistrationHttpMessageConverter
public OAuth2ClientRegistrationHttpMessageConverter() {
super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json"));
GenericHttpMessageConverter<Object> converter = HttpMessageConverters.getJsonMessageConverter();
Assert.notNull(converter, "Unable to locate a supported JSON message converter");
this.jsonMessageConverter = converter;
}
@Override
@ -187,7 +191,7 @@ public class OAuth2ClientRegistrationHttpMessageConverter
return (source) -> CLAIM_CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, targetDescriptor);
}
private static Instant convertClientSecretExpiresAt(Object clientSecretExpiresAt) {
private static @Nullable Instant convertClientSecretExpiresAt(Object clientSecretExpiresAt) {
if (clientSecretExpiresAt != null && String.valueOf(clientSecretExpiresAt).equals("0")) {
// 0 indicates that client_secret_expires_at does not expire
return null;

View File

@ -61,8 +61,7 @@ public class OAuth2TokenIntrospectionHttpMessageConverter
private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() {
};
private final GenericHttpMessageConverter<Object> jsonMessageConverter = HttpMessageConverters
.getJsonMessageConverter();
private final GenericHttpMessageConverter<Object> jsonMessageConverter;
private Converter<Map<String, Object>, OAuth2TokenIntrospection> tokenIntrospectionConverter = new MapOAuth2TokenIntrospectionConverter();
@ -70,6 +69,9 @@ public class OAuth2TokenIntrospectionHttpMessageConverter
public OAuth2TokenIntrospectionHttpMessageConverter() {
super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json"));
GenericHttpMessageConverter<Object> converter = HttpMessageConverters.getJsonMessageConverter();
Assert.notNull(converter, "Unable to locate a supported JSON message converter");
this.jsonMessageConverter = converter;
}
@Override

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* HTTP message converters for OAuth2 Authorization Server protocol representations.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.http.converter;
import org.jspecify.annotations.NullMarked;

View File

@ -19,6 +19,7 @@ package org.springframework.security.oauth2.server.authorization.jackson;
import java.util.Map;
import java.util.Set;
import org.jspecify.annotations.Nullable;
import tools.jackson.core.type.TypeReference;
import tools.jackson.databind.DeserializationContext;
import tools.jackson.databind.JsonNode;
@ -37,7 +38,7 @@ abstract class JsonNodeUtils {
static final TypeReference<Map<String, Object>> STRING_OBJECT_MAP = new TypeReference<>() {
};
static String findStringValue(JsonNode jsonNode, String fieldName) {
static @Nullable String findStringValue(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) {
return null;
}
@ -45,7 +46,7 @@ abstract class JsonNodeUtils {
return (value != null && value.isString()) ? value.stringValue() : null;
}
static <T> T findValue(JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference,
static <T> @Nullable T findValue(@Nullable JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference,
DeserializationContext context) {
if (jsonNode == null) {
return null;
@ -55,7 +56,7 @@ abstract class JsonNodeUtils {
? context.readTreeAsValue(value, context.getTypeFactory().constructType(valueTypeReference)) : null;
}
static JsonNode findObjectNode(JsonNode jsonNode, String fieldName) {
static @Nullable JsonNode findObjectNode(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) {
return null;
}

View File

@ -16,6 +16,10 @@
package org.springframework.security.oauth2.server.authorization.jackson;
import java.util.Collections;
import java.util.Map;
import org.jspecify.annotations.Nullable;
import tools.jackson.core.JsonParser;
import tools.jackson.databind.DeserializationContext;
import tools.jackson.databind.JsonNode;
@ -25,6 +29,7 @@ import tools.jackson.databind.exc.InvalidFormatException;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder;
import org.springframework.util.Assert;
/**
* A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}.
@ -45,16 +50,27 @@ final class OAuth2AuthorizationRequestDeserializer extends ValueDeserializer<OAu
private OAuth2AuthorizationRequest deserialize(JsonParser parser, DeserializationContext context, JsonNode root) {
AuthorizationGrantType authorizationGrantType = convertAuthorizationGrantType(
JsonNodeUtils.findObjectNode(root, "authorizationGrantType"));
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
Builder builder = getBuilder(parser, authorizationGrantType);
builder.authorizationUri(JsonNodeUtils.findStringValue(root, "authorizationUri"));
builder.clientId(JsonNodeUtils.findStringValue(root, "clientId"));
String authorizationUri = JsonNodeUtils.findStringValue(root, "authorizationUri");
Assert.notNull(authorizationUri, "authorizationUri cannot be null");
builder.authorizationUri(authorizationUri);
String clientId = JsonNodeUtils.findStringValue(root, "clientId");
Assert.notNull(clientId, "clientId cannot be null");
builder.clientId(clientId);
builder.redirectUri(JsonNodeUtils.findStringValue(root, "redirectUri"));
builder.scopes(JsonNodeUtils.findValue(root, "scopes", JsonNodeUtils.STRING_SET, context));
builder.state(JsonNodeUtils.findStringValue(root, "state"));
builder.additionalParameters(
JsonNodeUtils.findValue(root, "additionalParameters", JsonNodeUtils.STRING_OBJECT_MAP, context));
builder.authorizationRequestUri(JsonNodeUtils.findStringValue(root, "authorizationRequestUri"));
builder.attributes(JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP, context));
Map<String, Object> additionalParameters = JsonNodeUtils.findValue(root, "additionalParameters",
JsonNodeUtils.STRING_OBJECT_MAP, context);
builder.additionalParameters((additionalParameters != null) ? additionalParameters : Collections.emptyMap());
String authorizationRequestUri = JsonNodeUtils.findStringValue(root, "authorizationRequestUri");
if (authorizationRequestUri != null) {
builder.authorizationRequestUri(authorizationRequestUri);
}
Map<String, Object> attributes = JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP,
context);
builder.attributes((attributes != null) ? attributes : Collections.emptyMap());
return builder.build();
}
@ -66,7 +82,10 @@ final class OAuth2AuthorizationRequestDeserializer extends ValueDeserializer<OAu
AuthorizationGrantType.class);
}
private static AuthorizationGrantType convertAuthorizationGrantType(JsonNode jsonNode) {
private static @Nullable AuthorizationGrantType convertAuthorizationGrantType(@Nullable JsonNode jsonNode) {
if (jsonNode == null) {
return null;
}
String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) {
return AuthorizationGrantType.AUTHORIZATION_CODE;

View File

@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Jackson 3 ({@code tools.jackson}) serialization support for authorization server types.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.jackson;
import org.jspecify.annotations.NullMarked;

View File

@ -22,6 +22,7 @@ import java.util.Set;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
/**
* Utility class for {@code JsonNode}.
@ -41,7 +42,7 @@ abstract class JsonNodeUtils {
static final TypeReference<Map<String, Object>> STRING_OBJECT_MAP = new TypeReference<>() {
};
static String findStringValue(JsonNode jsonNode, String fieldName) {
static @Nullable String findStringValue(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) {
return null;
}
@ -49,7 +50,7 @@ abstract class JsonNodeUtils {
return (value != null && value.isTextual()) ? value.asText() : null;
}
static <T> T findValue(JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference,
static <T> @Nullable T findValue(@Nullable JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference,
ObjectMapper mapper) {
if (jsonNode == null) {
return null;
@ -58,7 +59,7 @@ abstract class JsonNodeUtils {
return (value != null && value.isContainerNode()) ? mapper.convertValue(value, valueTypeReference) : null;
}
static JsonNode findObjectNode(JsonNode jsonNode, String fieldName) {
static @Nullable JsonNode findObjectNode(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) {
return null;
}

View File

@ -17,6 +17,7 @@
package org.springframework.security.oauth2.server.authorization.jackson2;
import java.io.IOException;
import java.util.Map;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser;
@ -24,10 +25,12 @@ import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder;
import org.springframework.util.Assert;
/**
* A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}.
@ -57,27 +60,42 @@ final class OAuth2AuthorizationRequestDeserializer extends JsonDeserializer<OAut
AuthorizationGrantType authorizationGrantType = convertAuthorizationGrantType(
JsonNodeUtils.findObjectNode(root, "authorizationGrantType"));
Builder builder = getBuilder(parser, authorizationGrantType);
builder.authorizationUri(JsonNodeUtils.findStringValue(root, "authorizationUri"));
builder.clientId(JsonNodeUtils.findStringValue(root, "clientId"));
String authorizationUri = JsonNodeUtils.findStringValue(root, "authorizationUri");
Assert.notNull(authorizationUri, "authorizationUri cannot be null");
builder.authorizationUri(authorizationUri);
String clientId = JsonNodeUtils.findStringValue(root, "clientId");
Assert.notNull(clientId, "clientId cannot be null");
builder.clientId(clientId);
builder.redirectUri(JsonNodeUtils.findStringValue(root, "redirectUri"));
builder.scopes(JsonNodeUtils.findValue(root, "scopes", JsonNodeUtils.STRING_SET, mapper));
builder.state(JsonNodeUtils.findStringValue(root, "state"));
builder.additionalParameters(
JsonNodeUtils.findValue(root, "additionalParameters", JsonNodeUtils.STRING_OBJECT_MAP, mapper));
builder.authorizationRequestUri(JsonNodeUtils.findStringValue(root, "authorizationRequestUri"));
builder.attributes(JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP, mapper));
Map<String, Object> additionalParameters = JsonNodeUtils.findValue(root, "additionalParameters",
JsonNodeUtils.STRING_OBJECT_MAP, mapper);
if (additionalParameters != null) {
builder.additionalParameters(additionalParameters);
}
String authorizationRequestUri = JsonNodeUtils.findStringValue(root, "authorizationRequestUri");
if (authorizationRequestUri != null) {
builder.authorizationRequestUri(authorizationRequestUri);
}
Map<String, Object> attributes = JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP,
mapper);
if (attributes != null) {
builder.attributes(attributes);
}
return builder.build();
}
private Builder getBuilder(JsonParser parser, AuthorizationGrantType authorizationGrantType)
private Builder getBuilder(JsonParser parser, @Nullable AuthorizationGrantType authorizationGrantType)
throws JsonParseException {
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) {
if (authorizationGrantType != null
&& authorizationGrantType.equals(AuthorizationGrantType.AUTHORIZATION_CODE)) {
return OAuth2AuthorizationRequest.authorizationCode();
}
throw new JsonParseException(parser, "Invalid authorizationGrantType");
}
private static AuthorizationGrantType convertAuthorizationGrantType(JsonNode jsonNode) {
private static @Nullable AuthorizationGrantType convertAuthorizationGrantType(@Nullable JsonNode jsonNode) {
String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) {
return AuthorizationGrantType.AUTHORIZATION_CODE;

View File

@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Jackson 2 ({@code com.fasterxml.jackson}) serialization support for authorization
* server types (deprecated in favor of {@code jackson}).
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.jackson2;
import org.jspecify.annotations.NullMarked;

View File

@ -19,6 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc;
import java.net.URL;
import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
@ -53,7 +55,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* to after a logout has been performed.
* @return the post logout redirection {@code URI} values used by the Client
*/
default List<String> getPostLogoutRedirectUris() {
default @Nullable List<String> getPostLogoutRedirectUris() {
return getClaimAsStringList(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS);
}
@ -66,7 +68,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the {@link JwsAlgorithm JWS} algorithm that must be used for signing the
* {@link Jwt JWT} used to authenticate the Client at the Token Endpoint
*/
default String getTokenEndpointAuthenticationSigningAlgorithm() {
default @Nullable String getTokenEndpointAuthenticationSigningAlgorithm() {
return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG);
}
@ -77,7 +79,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the {@link SignatureAlgorithm JWS} algorithm required for signing the
* {@link OidcIdToken ID Token} issued to the Client
*/
default String getIdTokenSignedResponseAlgorithm() {
default @Nullable String getIdTokenSignedResponseAlgorithm() {
return getClaimAsString(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG);
}
@ -87,7 +89,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the Registration Access Token that can be used at the Client Configuration
* Endpoint
*/
default String getRegistrationAccessToken() {
default @Nullable String getRegistrationAccessToken() {
return getClaimAsString(OidcClientMetadataClaimNames.REGISTRATION_ACCESS_TOKEN);
}
@ -97,7 +99,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the {@code URL} of the Client Configuration Endpoint where the Registration
* Access Token can be used
*/
default URL getRegistrationClientUrl() {
default @Nullable URL getRegistrationClientUrl() {
return getClaimAsURL(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI);
}

View File

@ -19,11 +19,14 @@ package org.springframework.security.oauth2.server.authorization.oidc;
import java.net.URL;
import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationServerMetadataClaimAccessor;
import org.springframework.util.Assert;
/**
* A {@link ClaimAccessor} for the "claims" that can be returned in the OpenID Provider
@ -47,7 +50,9 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* @return the Subject Identifier types supported
*/
default List<String> getSubjectTypes() {
return getClaimAsStringList(OidcProviderMetadataClaimNames.SUBJECT_TYPES_SUPPORTED);
List<String> subjectTypes = getClaimAsStringList(OidcProviderMetadataClaimNames.SUBJECT_TYPES_SUPPORTED);
Assert.notNull(subjectTypes, "subjectTypes cannot be null");
return subjectTypes;
}
/**
@ -58,7 +63,10 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* {@link OidcIdToken ID Token}
*/
default List<String> getIdTokenSigningAlgorithms() {
return getClaimAsStringList(OidcProviderMetadataClaimNames.ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED);
List<String> idTokenSigningAlgorithms = getClaimAsStringList(
OidcProviderMetadataClaimNames.ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED);
Assert.notNull(idTokenSigningAlgorithms, "idTokenSigningAlgorithms cannot be null");
return idTokenSigningAlgorithms;
}
/**
@ -66,7 +74,7 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* {@code (userinfo_endpoint)}.
* @return the {@code URL} of the OpenID Connect 1.0 UserInfo Endpoint
*/
default URL getUserInfoEndpoint() {
default @Nullable URL getUserInfoEndpoint() {
return getClaimAsURL(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT);
}
@ -76,7 +84,9 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* @return the {@code URL} of the OpenID Connect 1.0 End Session Endpoint
*/
default URL getEndSessionEndpoint() {
return getClaimAsURL(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT);
URL endSessionEndpoint = getClaimAsURL(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT);
Assert.notNull(endSessionEndpoint, "endSessionEndpoint cannot be null");
return endSessionEndpoint;
}
}

View File

@ -18,10 +18,12 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider;
@ -99,7 +101,7 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = (OidcClientRegistrationAuthenticationToken) authentication;
if (!StringUtils.hasText(clientRegistrationAuthentication.getClientId())) {
@ -132,6 +134,7 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
}
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "authorizedAccessToken cannot be null");
if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
}
@ -149,8 +152,9 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication,
OAuth2Authorization authorization) {
RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(clientRegistrationAuthentication.getClientId());
String clientId = clientRegistrationAuthentication.getClientId();
Assert.hasText(clientId, "clientId cannot be empty");
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
}
@ -176,9 +180,11 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
@SuppressWarnings("unchecked")
private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken,
Set<String> requiredScope) {
Map<String, Object> claims = authorizedAccessToken.getClaims();
Assert.notNull(claims, "claims cannot be null");
Collection<String> authorizedScope = Collections.emptySet();
if (authorizedAccessToken.getClaims().containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE);
if (claims.containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) claims.get(OAuth2ParameterNames.SCOPE);
}
if (!authorizedScope.containsAll(requiredScope)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);

View File

@ -27,6 +27,7 @@ import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider;
@ -124,7 +125,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = (OidcClientRegistrationAuthenticationToken) authentication;
if (clientRegistrationAuthentication.getClientRegistration() == null) {
@ -157,6 +158,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
}
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "authorizedAccessToken cannot be null");
if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
}
@ -210,18 +212,24 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication,
OAuth2Authorization authorization) {
if (!isValidRedirectUris(clientRegistrationAuthentication.getClientRegistration().getRedirectUris())) {
OidcClientRegistration clientRegistrationRequest = clientRegistrationAuthentication.getClientRegistration();
Assert.notNull(clientRegistrationRequest, "clientRegistration cannot be null");
List<String> redirectUris = (clientRegistrationRequest.getRedirectUris() != null)
? clientRegistrationRequest.getRedirectUris() : Collections.emptyList();
if (!isValidRedirectUris(redirectUris)) {
throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI,
OidcClientMetadataClaimNames.REDIRECT_URIS);
}
if (!isValidRedirectUris(
clientRegistrationAuthentication.getClientRegistration().getPostLogoutRedirectUris())) {
List<String> postLogoutRedirectUris = (clientRegistrationRequest.getPostLogoutRedirectUris() != null)
? clientRegistrationRequest.getPostLogoutRedirectUris() : Collections.emptyList();
if (!isValidRedirectUris(postLogoutRedirectUris)) {
throwInvalidClientRegistration("invalid_client_metadata",
OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS);
}
if (!isValidTokenEndpointAuthenticationMethod(clientRegistrationAuthentication.getClientRegistration())) {
if (!isValidTokenEndpointAuthenticationMethod(clientRegistrationRequest)) {
throwInvalidClientRegistration("invalid_client_metadata",
OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD);
}
@ -230,8 +238,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
this.logger.trace("Validated client registration request parameters");
}
RegisteredClient registeredClient = this.registeredClientConverter
.convert(clientRegistrationAuthentication.getClientRegistration());
RegisteredClient registeredClient = this.registeredClientConverter.convert(clientRegistrationRequest);
if (StringUtils.hasText(registeredClient.getClientSecret())) {
// Encode the client secret
@ -240,8 +247,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
.build();
this.registeredClientRepository.save(updatedRegisteredClient);
if (ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()
.equals(clientRegistrationAuthentication.getClientRegistration()
.getTokenEndpointAuthenticationMethod())) {
.equals(clientRegistrationRequest.getTokenEndpointAuthenticationMethod())) {
// gh-1344 Return the hashed client_secret
registeredClient = updatedRegisteredClient;
}
@ -257,8 +263,10 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
OAuth2Authorization registeredClientAuthorization = registerAccessToken(registeredClient);
// Invalidate the "initial" access token as it can only be used once
OAuth2Authorization.Token<OAuth2AccessToken> initialAccessToken = authorization.getAccessToken();
Assert.notNull(initialAccessToken, "initialAccessToken cannot be null");
OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization)
.invalidate(authorization.getAccessToken().getToken());
.invalidate(initialAccessToken.getToken());
if (authorization.getRefreshToken() != null) {
builder.invalidate(authorization.getRefreshToken().getToken());
}
@ -271,8 +279,11 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
Map<String, Object> clientRegistrationClaims = this.clientRegistrationConverter.convert(registeredClient)
.getClaims();
OAuth2Authorization.Token<OAuth2AccessToken> registrationAccessToken = registeredClientAuthorization
.getAccessToken();
Assert.notNull(registrationAccessToken, "registrationAccessToken cannot be null");
OidcClientRegistration clientRegistration = OidcClientRegistration.withClaims(clientRegistrationClaims)
.registrationAccessToken(registeredClientAuthorization.getAccessToken().getToken().getTokenValue())
.registrationAccessToken(registrationAccessToken.getToken().getTokenValue())
.build();
if (this.logger.isTraceEnabled()) {
@ -338,9 +349,11 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
@SuppressWarnings("unchecked")
private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken,
Set<String> requiredScope) {
Map<String, Object> claims = authorizedAccessToken.getClaims();
Assert.notNull(claims, "claims cannot be null");
Collection<String> authorizedScope = Collections.emptySet();
if (authorizedAccessToken.getClaims().containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE);
if (claims.containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) claims.get(OAuth2ParameterNames.SCOPE);
}
if (!authorizedScope.containsAll(requiredScope)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);

View File

@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.io.Serial;
import java.util.Collections;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration;
@ -44,9 +45,9 @@ public class OidcClientRegistrationAuthenticationToken extends AbstractAuthentic
private final Authentication principal;
private final OidcClientRegistration clientRegistration;
private final @Nullable OidcClientRegistration clientRegistration;
private final String clientId;
private final @Nullable String clientId;
/**
* Constructs an {@code OidcClientRegistrationAuthenticationToken} using the provided
@ -95,7 +96,7 @@ public class OidcClientRegistrationAuthenticationToken extends AbstractAuthentic
* Returns the client registration.
* @return the client registration
*/
public OidcClientRegistration getClientRegistration() {
public @Nullable OidcClientRegistration getClientRegistration() {
return this.clientRegistration;
}
@ -103,8 +104,7 @@ public class OidcClientRegistrationAuthenticationToken extends AbstractAuthentic
* Returns the client identifier.
* @return the client identifier
*/
@Nullable
public String getClientId() {
public @Nullable String getClientId() {
return this.clientId;
}

View File

@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationContext;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
@ -46,9 +47,8 @@ public final class OidcLogoutAuthenticationContext implements OAuth2Authenticati
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -63,7 +63,9 @@ public final class OidcLogoutAuthenticationContext implements OAuth2Authenticati
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
}
/**

View File

@ -26,6 +26,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
@ -99,7 +100,7 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
OAuth2Authorization authorization = this.authorizationService
.findByToken(oidcLogoutAuthentication.getIdTokenHint(), ID_TOKEN_TOKEN_TYPE);
if (authorization == null) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint");
throw createException(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint");
}
if (this.logger.isTraceEnabled()) {
@ -107,13 +108,15 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
}
OAuth2Authorization.Token<OidcIdToken> authorizedIdToken = authorization.getToken(OidcIdToken.class);
Assert.notNull(authorizedIdToken, "authorizedIdToken cannot be null");
if (authorizedIdToken.isInvalidated() || authorizedIdToken.isBeforeUse()) {
// Expired ID Token should be accepted
throwError(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint");
throw createException(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint");
}
RegisteredClient registeredClient = this.registeredClientRepository
.findById(authorization.getRegisteredClientId());
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client");
@ -124,11 +127,11 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
// Validate client identity
List<String> audClaim = idToken.getAudience();
if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD);
throw createException(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD);
}
if (StringUtils.hasText(oidcLogoutAuthentication.getClientId())
&& !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
}
OidcLogoutAuthenticationContext context = OidcLogoutAuthenticationContext.with(oidcLogoutAuthentication)
@ -144,9 +147,10 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
if (oidcLogoutAuthentication.isPrincipalAuthenticated()) {
Authentication currentUserPrincipal = (Authentication) oidcLogoutAuthentication.getPrincipal();
Authentication authorizedUserPrincipal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(authorizedUserPrincipal, "authorizedUserPrincipal cannot be null");
if (!StringUtils.hasText(idToken.getSubject())
|| !currentUserPrincipal.getName().equals(authorizedUserPrincipal.getName())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.SUB);
throw createException(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.SUB);
}
// Check for active session
@ -166,7 +170,7 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
String sidClaim = idToken.getClaim("sid");
if (!StringUtils.hasText(sidClaim) || !sidClaim.equals(sessionIdHash)) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, "sid");
throw createException(OAuth2ErrorCodes.INVALID_TOKEN, "sid");
}
}
}
@ -205,8 +209,10 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
this.authenticationValidator = authenticationValidator;
}
private SessionInformation findSessionInformation(Authentication principal, String sessionId) {
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), true);
private @Nullable SessionInformation findSessionInformation(Authentication principal, String sessionId) {
Object sessionPrincipal = principal.getPrincipal();
Assert.notNull(sessionPrincipal, "sessionPrincipal cannot be null");
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(sessionPrincipal, true);
SessionInformation sessionInformation = null;
if (!CollectionUtils.isEmpty(sessions)) {
for (SessionInformation session : sessions) {
@ -219,10 +225,10 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
return sessionInformation;
}
private static void throwError(String errorCode, String parameterName) {
private static OAuth2AuthenticationException createException(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(errorCode, "OpenID Connect 1.0 Logout Request Parameter: " + parameterName,
"https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling");
throw new OAuth2AuthenticationException(error);
return new OAuth2AuthenticationException(error);
}
private static String createHash(String value) throws NoSuchAlgorithmException {

View File

@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.io.Serial;
import java.util.Collections;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
@ -42,17 +43,17 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
private final String idTokenHint;
private final OidcIdToken idToken;
private final @Nullable OidcIdToken idToken;
private final Authentication principal;
private final String sessionId;
private final @Nullable String sessionId;
private final String clientId;
private final @Nullable String clientId;
private final String postLogoutRedirectUri;
private final @Nullable String postLogoutRedirectUri;
private final String state;
private final @Nullable String state;
/**
* Constructs an {@code OidcLogoutAuthenticationToken} using the provided parameters.
@ -147,8 +148,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* Returns the ID Token previously issued by the Provider to the Client.
* @return the ID Token previously issued by the Provider to the Client
*/
@Nullable
public OidcIdToken getIdToken() {
public @Nullable OidcIdToken getIdToken() {
return this.idToken;
}
@ -156,8 +156,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* Returns the End-User's current authenticated session identifier with the Provider.
* @return the End-User's current authenticated session identifier with the Provider
*/
@Nullable
public String getSessionId() {
public @Nullable String getSessionId() {
return this.sessionId;
}
@ -165,8 +164,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* Returns the client identifier the ID Token was issued to.
* @return the client identifier
*/
@Nullable
public String getClientId() {
public @Nullable String getClientId() {
return this.clientId;
}
@ -176,8 +174,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* @return the URI which the Client is requesting that the End-User's User Agent be
* redirected to after a logout has been performed
*/
@Nullable
public String getPostLogoutRedirectUri() {
public @Nullable String getPostLogoutRedirectUri() {
return this.postLogoutRedirectUri;
}
@ -187,8 +184,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* @return the opaque value used by the Client to maintain state between the logout
* request and the callback to the {@link #getPostLogoutRedirectUri()}
*/
@Nullable
public String getState() {
public @Nullable String getState() {
return this.state;
}

View File

@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import org.springframework.lang.Nullable;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
@ -48,9 +49,8 @@ public final class OidcUserInfoAuthenticationContext implements OAuth2Authentica
}
@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}
@ -65,7 +65,9 @@ public final class OidcUserInfoAuthenticationContext implements OAuth2Authentica
* @return the {@link OAuth2AccessToken}
*/
public OAuth2AccessToken getAccessToken() {
return get(OAuth2AccessToken.class);
OAuth2AccessToken accessToken = get(OAuth2AccessToken.class);
Assert.notNull(accessToken, "accessToken cannot be null");
return accessToken;
}
/**
@ -73,7 +75,9 @@ public final class OidcUserInfoAuthenticationContext implements OAuth2Authentica
* @return the {@link OAuth2Authorization}
*/
public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class);
OAuth2Authorization authorization = get(OAuth2Authorization.class);
Assert.notNull(authorization, "authorization cannot be null");
return authorization;
}
/**

View File

@ -98,6 +98,7 @@ public final class OidcUserInfoAuthenticationProvider implements AuthenticationP
}
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "authorizedAccessToken cannot be null");
if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
}
@ -191,7 +192,9 @@ public final class OidcUserInfoAuthenticationProvider implements AuthenticationP
@Override
public OidcUserInfo apply(OidcUserInfoAuthenticationContext authenticationContext) {
OAuth2Authorization authorization = authenticationContext.getAuthorization();
OidcIdToken idToken = authorization.getToken(OidcIdToken.class).getToken();
OAuth2Authorization.Token<OidcIdToken> authorizedIdToken = authorization.getToken(OidcIdToken.class);
Assert.notNull(authorizedIdToken, "authorizedIdToken cannot be null");
OidcIdToken idToken = authorizedIdToken.getToken();
OAuth2AccessToken accessToken = authenticationContext.getAccessToken();
Map<String, Object> scopeRequestedClaims = getClaimsRequestedByScope(idToken.getClaims(),
accessToken.getScopes());

View File

@ -19,6 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.io.Serial;
import java.util.Collections;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
@ -40,7 +42,7 @@ public class OidcUserInfoAuthenticationToken extends AbstractAuthenticationToken
private final Authentication principal;
private final OidcUserInfo userInfo;
private final @Nullable OidcUserInfo userInfo;
/**
* Constructs an {@code OidcUserInfoAuthenticationToken} using the provided
@ -82,9 +84,9 @@ public class OidcUserInfoAuthenticationToken extends AbstractAuthenticationToken
/**
* Returns the UserInfo claims.
* @return the UserInfo claims
* @return the UserInfo claims, or {@code null} if not provided
*/
public OidcUserInfo getUserInfo() {
public @Nullable OidcUserInfo getUserInfo() {
return this.userInfo;
}

Some files were not shown because too many files have changed in this diff Show More