Using collection of strings for groups

This commit is contained in:
Matt Patrick 2024-12-11 18:18:38 -05:00
parent 7d5c38dc82
commit feb8c2730e
2 changed files with 3 additions and 22 deletions

View File

@ -34,7 +34,6 @@ import org.apache.nifi.registry.security.key.Key;
import org.apache.nifi.registry.security.key.KeyService; import org.apache.nifi.registry.security.key.KeyService;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -44,7 +43,6 @@ import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@Service @Service
public class JwtService { public class JwtService {
@ -150,14 +148,7 @@ public class JwtService {
return this.generateSignedToken(identity, preferredUsername, issuer, audience, expirationMillis, null); return this.generateSignedToken(identity, preferredUsername, issuer, audience, expirationMillis, null);
} }
public String generateSignedToken( public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis, Collection<String> groups) throws JwtException {
String identity,
String preferredUsername,
String issuer,
String audience,
long expirationMillis,
Collection<? extends GrantedAuthority> authorities) throws JwtException {
if (identity == null || StringUtils.isEmpty(identity)) { if (identity == null || StringUtils.isEmpty(identity)) {
String errorMessage = "Cannot generate a JWT for a token with an empty identity"; String errorMessage = "Cannot generate a JWT for a token with an empty identity";
errorMessage = issuer != null ? errorMessage + " issued by " + issuer + "." : "."; errorMessage = issuer != null ? errorMessage + " issued by " + issuer + "." : ".";
@ -183,7 +174,7 @@ public class JwtService {
.audience().add(audience).and() .audience().add(audience).and()
.claim(USERNAME_CLAIM, preferredUsername) .claim(USERNAME_CLAIM, preferredUsername)
.claim(KEY_ID_CLAIM, key.getId()) .claim(KEY_ID_CLAIM, key.getId())
.claim(GROUPS_CLAIM, authorities.stream().map(GrantedAuthority::getAuthority).collect(Collectors.toSet())) .claim(GROUPS_CLAIM, groups)
.issuedAt(now.getTime()) .issuedAt(now.getTime())
.expiration(expiration.getTime()) .expiration(expiration.getTime())
.signWith(Keys.hmacShaKeyFor(keyBytes), SIGNATURE_ALGORITHM).compact(); .signWith(Keys.hmacShaKeyFor(keyBytes), SIGNATURE_ALGORITHM).compact();

View File

@ -21,11 +21,9 @@ import java.net.URI;
import java.net.URL; import java.net.URL;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Calendar; import java.util.Calendar;
import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -41,7 +39,6 @@ import org.apache.nifi.registry.web.security.authentication.jwt.JwtService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
@ -431,15 +428,8 @@ public class StandardOidcIdentityProvider implements OidcIdentityProvider {
final long expiresIn = expiration.getTime() - now.getTimeInMillis(); final long expiresIn = expiration.getTime() - now.getTimeInMillis();
final String issuer = claimsSet.getIssuer().getValue(); final String issuer = claimsSet.getIssuer().getValue();
Set<SimpleGrantedAuthority> authorities = groups != null ? groups.stream().map(
SimpleGrantedAuthority::new).collect(
Collectors.collectingAndThen(
Collectors.toSet(),
Collections::unmodifiableSet
)) : null;
// convert into a nifi jwt for retrieval later // convert into a nifi jwt for retrieval later
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn, authorities); return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn, groups);
} }
private String retrieveIdentityFromUserInfoEndpoint(OIDCTokens oidcTokens) throws IOException { private String retrieveIdentityFromUserInfoEndpoint(OIDCTokens oidcTokens) throws IOException {