Replace Streams with Loops

First version of replacing streams

fix wwwAuthenticate and codestyle

fix errors in implementation to pass tests

Fix review notes

Remove uneccessary final to align with cb

Short circuit way to authorize

Simplify error message, make code readably

Return error while duplicate key found

Delete check for duplicate, checkstyle issues

Return duplicate error

Fixes gh-7154
This commit is contained in:
kostya05983 2019-08-04 17:11:24 +07:00 committed by Josh Cummings
parent d6d0d89ff8
commit f6c650db47
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
30 changed files with 283 additions and 184 deletions

View File

@ -43,12 +43,11 @@ import org.springframework.security.crypto.factory.PasswordEncoderFactories;
import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
/** /**
* Exports the authentication {@link Configuration} * Exports the authentication {@link Configuration}
@ -153,10 +152,7 @@ public class AuthenticationConfiguration {
} }
String beanName; String beanName;
if (beanNamesForType.length > 1) { if (beanNamesForType.length > 1) {
List<String> primaryBeanNames = Arrays.stream(beanNamesForType) List<String> primaryBeanNames = getPrimaryBeanNames(beanNamesForType);
.filter(i -> applicationContext instanceof ConfigurableApplicationContext)
.filter(n -> ((ConfigurableApplicationContext) applicationContext).getBeanFactory().getBeanDefinition(n).isPrimary())
.collect(Collectors.toList());
Assert.isTrue(primaryBeanNames.size() != 0, () -> "Found " + beanNamesForType.length Assert.isTrue(primaryBeanNames.size() != 0, () -> "Found " + beanNamesForType.length
+ " beans for type " + interfaceName + ", but none marked as primary"); + " beans for type " + interfaceName + ", but none marked as primary");
@ -175,6 +171,20 @@ public class AuthenticationConfiguration {
return (T) proxyFactory.getObject(); return (T) proxyFactory.getObject();
} }
private List<String> getPrimaryBeanNames(String[] beanNamesForType) {
List<String> list = new ArrayList<>();
if (!(applicationContext instanceof ConfigurableApplicationContext)) {
return Collections.emptyList();
}
for (String beanName : beanNamesForType) {
if (((ConfigurableApplicationContext) applicationContext).getBeanFactory()
.getBeanDefinition(beanName).isPrimary()) {
list.add(beanName);
}
}
return list;
}
private AuthenticationManager getAuthenticationManagerBean() { private AuthenticationManager getAuthenticationManagerBean() {
return lazyBean(AuthenticationManager.class); return lazyBean(AuthenticationManager.class);
} }

View File

@ -22,7 +22,6 @@ import reactor.core.publisher.Mono;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.stream.Stream;
/** /**
* A {@link ReactiveAuthorizationManager} that determines if the current user is * A {@link ReactiveAuthorizationManager} that determines if the current user is
@ -109,9 +108,14 @@ public class AuthorityReactiveAuthorizationManager<T> implements ReactiveAuthori
Assert.notNull(role, "role cannot be null"); Assert.notNull(role, "role cannot be null");
} }
return hasAnyAuthority(Stream.of(roles) return hasAnyAuthority(toNamedRolesArray(roles));
.map(r -> "ROLE_" + r) }
.toArray(String[]::new)
); private static String[] toNamedRolesArray(String... roles) {
String[] result = new String[roles.length];
for (int i=0; i < roles.length; i++) {
result[i] = "ROLE_" + roles[i];
}
return result;
} }
} }

View File

@ -16,8 +16,8 @@
package org.springframework.security.converter; package org.springframework.security.converter;
import java.io.BufferedReader;
import java.io.InputStream; import java.io.InputStream;
import java.io.BufferedReader;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.security.KeyFactory; import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
@ -25,8 +25,8 @@ import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey; import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec; import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Base64;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
@ -66,10 +66,13 @@ public class RsaKeyConverters {
Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(PKCS8_PEM_HEADER), Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(PKCS8_PEM_HEADER),
"Key is not in PEM-encoded PKCS#8 format, " + "Key is not in PEM-encoded PKCS#8 format, " +
"please check that the header begins with -----" + PKCS8_PEM_HEADER + "-----"); "please check that the header begins with -----" + PKCS8_PEM_HEADER + "-----");
String base64Encoded = lines.stream() StringBuilder base64Encoded = new StringBuilder();
.filter(RsaKeyConverters::isNotPkcs8Wrapper) for (String line : lines) {
.collect(Collectors.joining()); if (RsaKeyConverters.isNotPkcs8Wrapper(line)) {
byte[] pkcs8 = Base64.getDecoder().decode(base64Encoded); base64Encoded.append(line);
}
}
byte[] pkcs8 = Base64.getDecoder().decode(base64Encoded.toString());
try { try {
return (RSAPrivateKey) keyFactory.generatePrivate( return (RSAPrivateKey) keyFactory.generatePrivate(
@ -97,10 +100,13 @@ public class RsaKeyConverters {
Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(X509_PEM_HEADER), Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(X509_PEM_HEADER),
"Key is not in PEM-encoded X.509 format, " + "Key is not in PEM-encoded X.509 format, " +
"please check that the header begins with -----" + X509_PEM_HEADER + "-----"); "please check that the header begins with -----" + X509_PEM_HEADER + "-----");
String base64Encoded = lines.stream() StringBuilder base64Encoded = new StringBuilder();
.filter(RsaKeyConverters::isNotX509Wrapper) for (String line : lines) {
.collect(Collectors.joining()); if (RsaKeyConverters.isNotX509Wrapper(line)) {
byte[] x509 = Base64.getDecoder().decode(base64Encoded); base64Encoded.append(line);
}
}
byte[] x509 = Base64.getDecoder().decode(base64Encoded.toString());
try { try {
return (RSAPublicKey) keyFactory.generatePublic( return (RSAPublicKey) keyFactory.generatePublic(

View File

@ -19,8 +19,7 @@ package org.springframework.security.core.userdetails;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
import java.util.function.Function; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -56,7 +55,10 @@ public class MapReactiveUserDetailsService implements ReactiveUserDetailsService
*/ */
public MapReactiveUserDetailsService(Collection<UserDetails> users) { public MapReactiveUserDetailsService(Collection<UserDetails> users) {
Assert.notEmpty(users, "users cannot be null or empty"); Assert.notEmpty(users, "users cannot be null or empty");
this.users = users.stream().collect(Collectors.toConcurrentMap( u -> getKey(u.getUsername()), Function.identity())); this.users = new ConcurrentHashMap<>();
for (UserDetails user : users) {
this.users.put(getKey(user.getUsername()), user);
}
} }
@Override @Override

View File

@ -46,6 +46,13 @@ public class MapReactiveUserDetailsServiceTests {
new MapReactiveUserDetailsService(users); new MapReactiveUserDetailsService(users);
} }
@Test
public void constructorCaseIntensiveKey() {
UserDetails userDetails = User.withUsername("USER").password("password").roles("USER").build();
MapReactiveUserDetailsService userDetailsService = new MapReactiveUserDetailsService(userDetails);
assertThat(userDetailsService.findByUsername("user").block()).isEqualTo(userDetails);
}
@Test @Test
public void findByUsernameWhenFoundThenReturns() { public void findByUsernameWhenFoundThenReturns() {
assertThat((users.findByUsername(USER_DETAILS.getUsername()).block())).isEqualTo(USER_DETAILS); assertThat((users.findByUsername(USER_DETAILS.getUsername()).block())).isEqualTo(USER_DETAILS);

View File

@ -22,7 +22,6 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects;
/** /**
* An implementation of an {@link OAuth2AuthorizedClientProvider} that simply delegates * An implementation of an {@link OAuth2AuthorizedClientProvider} that simply delegates
@ -64,10 +63,12 @@ public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2Aut
@Nullable @Nullable
public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
Assert.notNull(context, "context cannot be null"); Assert.notNull(context, "context cannot be null");
return this.authorizedClientProviders.stream() for (OAuth2AuthorizedClientProvider authorizedClientProvider : authorizedClientProviders) {
.map(authorizedClientProvider -> authorizedClientProvider.authorize(context)) OAuth2AuthorizedClient oauth2AuthorizedClient = authorizedClientProvider.authorize(context);
.filter(Objects::nonNull) if (oauth2AuthorizedClient != null) {
.findFirst() return oauth2AuthorizedClient;
.orElse(null); }
}
return null;
} }
} }

View File

@ -23,11 +23,11 @@ import org.springframework.util.Assert;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.List;
import java.util.LinkedHashMap;
import java.util.ArrayList;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Collectors;
/** /**
* A builder that builds a {@link DelegatingOAuth2AuthorizedClientProvider} composed of * A builder that builds a {@link DelegatingOAuth2AuthorizedClientProvider} composed of
@ -286,10 +286,10 @@ public final class OAuth2AuthorizedClientProviderBuilder {
* @return the {@link DelegatingOAuth2AuthorizedClientProvider} * @return the {@link DelegatingOAuth2AuthorizedClientProvider}
*/ */
public OAuth2AuthorizedClientProvider build() { public OAuth2AuthorizedClientProvider build() {
List<OAuth2AuthorizedClientProvider> authorizedClientProviders = List<OAuth2AuthorizedClientProvider> authorizedClientProviders = new ArrayList<>();
this.builders.values().stream() for (Builder builder : this.builders.values()) {
.map(Builder::build) authorizedClientProviders.add(builder.build());
.collect(Collectors.toList()); }
return new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders); return new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
} }

View File

@ -32,7 +32,6 @@ import java.time.Instant;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
/** /**
* An {@link OAuth2TokenValidator} responsible for * An {@link OAuth2TokenValidator} responsible for
@ -137,11 +136,8 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator<Jwt> {
} }
private static OAuth2Error invalidIdToken(Map<String, Object> invalidClaims) { private static OAuth2Error invalidIdToken(Map<String, Object> invalidClaims) {
String claimsDetail = invalidClaims.entrySet().stream()
.map(it -> it.getKey() + " (" + it.getValue() + ")")
.collect(Collectors.joining(", "));
return new OAuth2Error("invalid_id_token", return new OAuth2Error("invalid_id_token",
"The ID Token contains invalid claims: " + claimsDetail, "The ID Token contains invalid claims: " + invalidClaims,
"https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation"); "https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation");
} }

View File

@ -22,12 +22,7 @@ import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collector;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toConcurrentMap;
/** /**
* A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory. * A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory.
@ -62,9 +57,19 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr
private static Map<String, ClientRegistration> createRegistrationsMap(List<ClientRegistration> registrations) { private static Map<String, ClientRegistration> createRegistrationsMap(List<ClientRegistration> registrations) {
Assert.notEmpty(registrations, "registrations cannot be empty"); Assert.notEmpty(registrations, "registrations cannot be empty");
Collector<ClientRegistration, ?, ConcurrentMap<String, ClientRegistration>> collector = return toUnmodifiableConcurrentMap(registrations);
toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity()); }
return registrations.stream().collect(collectingAndThen(collector, Collections::unmodifiableMap));
private static Map<String, ClientRegistration> toUnmodifiableConcurrentMap(List<ClientRegistration> registrations) {
ConcurrentHashMap<String, ClientRegistration> result = new ConcurrentHashMap<>();
for (ClientRegistration registration : registrations) {
if (result.containsKey(registration.getRegistrationId())) {
throw new IllegalStateException(String.format("Duplicate key %s",
registration.getRegistrationId()));
}
result.put(registration.getRegistrationId(), registration);
}
return Collections.unmodifiableMap(result);
} }
/** /**

View File

@ -19,8 +19,7 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -61,11 +60,9 @@ public final class InMemoryReactiveClientRegistrationRepository
*/ */
public InMemoryReactiveClientRegistrationRepository(List<ClientRegistration> registrations) { public InMemoryReactiveClientRegistrationRepository(List<ClientRegistration> registrations) {
Assert.notEmpty(registrations, "registrations cannot be null or empty"); Assert.notEmpty(registrations, "registrations cannot be null or empty");
this.clientIdToClientRegistration = registrations.stream() this.clientIdToClientRegistration = toConcurrentMap(registrations);
.collect(Collectors.toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity()));
} }
@Override @Override
public Mono<ClientRegistration> findByRegistrationId(String registrationId) { public Mono<ClientRegistration> findByRegistrationId(String registrationId) {
return Mono.justOrEmpty(this.clientIdToClientRegistration.get(registrationId)); return Mono.justOrEmpty(this.clientIdToClientRegistration.get(registrationId));
@ -80,4 +77,12 @@ public final class InMemoryReactiveClientRegistrationRepository
public Iterator<ClientRegistration> iterator() { public Iterator<ClientRegistration> iterator() {
return this.clientIdToClientRegistration.values().iterator(); return this.clientIdToClientRegistration.values().iterator();
} }
private ConcurrentHashMap<String, ClientRegistration> toConcurrentMap(List<ClientRegistration> registrations) {
ConcurrentHashMap<String, ClientRegistration> result = new ConcurrentHashMap<>();
for (ClientRegistration registration : registrations) {
result.put(registration.getRegistrationId(), registration);
}
return result;
}
} }

View File

@ -229,6 +229,16 @@ public class OidcIdTokenValidatorTests {
.allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
} }
@Test
public void validateFormatError() {
this.claims.remove(IdTokenClaimNames.SUB);
this.claims.remove(IdTokenClaimNames.AUD);
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch(msg -> msg.equals("The ID Token contains invalid claims: {sub=null, aud=null}"));
}
private Collection<OAuth2Error> validateIdToken() { private Collection<OAuth2Error> validateIdToken() {
Jwt idToken = new Jwt("token123", this.issuedAt, this.expiresAt, this.headers, this.claims); Jwt idToken = new Jwt("token123", this.issuedAt, this.expiresAt, this.headers, this.claims);
OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build()); OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build());

View File

@ -50,12 +50,6 @@ public class InMemoryClientRegistrationRepositoryTests {
new InMemoryClientRegistrationRepository(registrations); new InMemoryClientRegistrationRepository(registrations);
} }
@Test(expected = IllegalStateException.class)
public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() {
List<ClientRegistration> registrations = Arrays.asList(this.registration, this.registration);
new InMemoryClientRegistrationRepository(registrations);
}
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void constructorMapClientRegistrationWhenNullThenIllegalArgumentException() { public void constructorMapClientRegistrationWhenNullThenIllegalArgumentException() {
new InMemoryClientRegistrationRepository((Map<String, ClientRegistration>) null); new InMemoryClientRegistrationRepository((Map<String, ClientRegistration>) null);
@ -67,6 +61,12 @@ public class InMemoryClientRegistrationRepositoryTests {
assertThat(clients).isEmpty(); assertThat(clients).isEmpty();
} }
@Test(expected = IllegalStateException.class)
public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() {
List<ClientRegistration> registrations = Arrays.asList(this.registration, this.registration);
new InMemoryClientRegistrationRepository(registrations);
}
@Test @Test
public void findByRegistrationIdWhenFoundThenFound() { public void findByRegistrationIdWhenFoundThenFound() {
String id = this.registration.getRegistrationId(); String id = this.registration.getRegistrationId();

View File

@ -23,9 +23,8 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.ArrayList;
/** /**
* @author Joe Grandja * @author Joe Grandja
@ -64,10 +63,13 @@ final class ObjectToListStringConverter implements ConditionalGenericConverter {
} }
} }
if (source instanceof Collection) { if (source instanceof Collection) {
return ((Collection<?>) source).stream() Collection<String> results = new ArrayList<>();
.filter(Objects::nonNull) for (Object object : ((Collection<?>) source)) {
.map(Objects::toString) if (object != null) {
.collect(Collectors.toList()); results.add(object.toString());
}
}
return results;
} }
return Collections.singletonList(source.toString()); return Collections.singletonList(source.toString());
} }

View File

@ -26,13 +26,11 @@ import org.springframework.web.util.UriComponentsBuilder;
import java.io.Serializable; import java.io.Serializable;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
/** /**
* A representation of an OAuth 2.0 Authorization Request * A representation of an OAuth 2.0 Authorization Request
@ -275,8 +273,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
*/ */
public Builder scope(String... scope) { public Builder scope(String... scope) {
if (scope != null && scope.length > 0) { if (scope != null && scope.length > 0) {
return this.scopes(Arrays.stream(scope).collect( return this.scopes(toLinkedHashSet(scope));
Collectors.toCollection(LinkedHashSet::new)));
} }
return this; return this;
} }
@ -401,5 +398,11 @@ public final class OAuth2AuthorizationRequest implements Serializable {
.build() .build()
.toUriString(); .toUriString();
} }
private LinkedHashSet<String> toLinkedHashSet(String... scope) {
LinkedHashSet<String> result = new LinkedHashSet<>();
Collections.addAll(result, scope);
return result;
}
} }
} }

View File

@ -37,14 +37,13 @@ import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Arrays; import java.util.HashSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.Arrays;
import java.util.stream.Stream; import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.HashMap;
/** /**
* A {@link HttpMessageConverter} for an {@link OAuth2AccessTokenResponse OAuth 2.0 Access Token Response}. * A {@link HttpMessageConverter} for an {@link OAuth2AccessTokenResponse OAuth 2.0 Access Token Response}.
@ -132,12 +131,13 @@ public class OAuth2AccessTokenResponseHttpMessageConverter extends AbstractHttpM
* OAuth 2.0 Access Token Response parameters to an {@link OAuth2AccessTokenResponse}. * OAuth 2.0 Access Token Response parameters to an {@link OAuth2AccessTokenResponse}.
*/ */
private static class OAuth2AccessTokenResponseConverter implements Converter<Map<String, String>, OAuth2AccessTokenResponse> { private static class OAuth2AccessTokenResponseConverter implements Converter<Map<String, String>, OAuth2AccessTokenResponse> {
private static final Set<String> TOKEN_RESPONSE_PARAMETER_NAMES = Stream.of( private static final Set<String> TOKEN_RESPONSE_PARAMETER_NAMES = new HashSet<>(Arrays.asList(
OAuth2ParameterNames.ACCESS_TOKEN, OAuth2ParameterNames.ACCESS_TOKEN,
OAuth2ParameterNames.TOKEN_TYPE, OAuth2ParameterNames.TOKEN_TYPE,
OAuth2ParameterNames.EXPIRES_IN, OAuth2ParameterNames.EXPIRES_IN,
OAuth2ParameterNames.REFRESH_TOKEN, OAuth2ParameterNames.REFRESH_TOKEN,
OAuth2ParameterNames.SCOPE).collect(Collectors.toSet()); OAuth2ParameterNames.SCOPE
));
@Override @Override
public OAuth2AccessTokenResponse convert(Map<String, String> tokenResponseParameters) { public OAuth2AccessTokenResponse convert(Map<String, String> tokenResponseParameters) {
@ -159,15 +159,17 @@ public class OAuth2AccessTokenResponseHttpMessageConverter extends AbstractHttpM
Set<String> scopes = Collections.emptySet(); Set<String> scopes = Collections.emptySet();
if (tokenResponseParameters.containsKey(OAuth2ParameterNames.SCOPE)) { if (tokenResponseParameters.containsKey(OAuth2ParameterNames.SCOPE)) {
String scope = tokenResponseParameters.get(OAuth2ParameterNames.SCOPE); String scope = tokenResponseParameters.get(OAuth2ParameterNames.SCOPE);
scopes = Arrays.stream(StringUtils.delimitedListToStringArray(scope, " ")).collect(Collectors.toSet()); scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
} }
String refreshToken = tokenResponseParameters.get(OAuth2ParameterNames.REFRESH_TOKEN); String refreshToken = tokenResponseParameters.get(OAuth2ParameterNames.REFRESH_TOKEN);
Map<String, Object> additionalParameters = new LinkedHashMap<>(); Map<String, Object> additionalParameters = new LinkedHashMap<>();
tokenResponseParameters.entrySet().stream() for (Map.Entry<String, String> entry : tokenResponseParameters.entrySet()) {
.filter(e -> !TOKEN_RESPONSE_PARAMETER_NAMES.contains(e.getKey())) if (!TOKEN_RESPONSE_PARAMETER_NAMES.contains(entry.getKey())) {
.forEach(e -> additionalParameters.put(e.getKey(), e.getValue())); additionalParameters.put(entry.getKey(), entry.getValue());
}
}
return OAuth2AccessTokenResponse.withToken(accessToken) return OAuth2AccessTokenResponse.withToken(accessToken)
.tokenType(accessTokenType) .tokenType(accessTokenType)
@ -205,8 +207,9 @@ public class OAuth2AccessTokenResponseHttpMessageConverter extends AbstractHttpM
parameters.put(OAuth2ParameterNames.REFRESH_TOKEN, tokenResponse.getRefreshToken().getTokenValue()); parameters.put(OAuth2ParameterNames.REFRESH_TOKEN, tokenResponse.getRefreshToken().getTokenValue());
} }
if (!CollectionUtils.isEmpty(tokenResponse.getAdditionalParameters())) { if (!CollectionUtils.isEmpty(tokenResponse.getAdditionalParameters())) {
tokenResponse.getAdditionalParameters().entrySet().stream() for (Map.Entry<String, Object> entry : tokenResponse.getAdditionalParameters().entrySet()) {
.forEach(e -> parameters.put(e.getKey(), e.getValue().toString())); parameters.put(entry.getKey(), entry.getValue().toString());
}
} }
return parameters; return parameters;

View File

@ -20,16 +20,15 @@ import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.io.Serializable; import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.SortedSet; import java.util.Collections;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.stream.Collectors; import java.util.SortedSet;
import java.util.Comparator;
import java.util.LinkedHashSet;
/** /**
* The default implementation of an {@link OAuth2User}. * The default implementation of an {@link OAuth2User}.
@ -43,8 +42,8 @@ import java.util.stream.Collectors;
* and returning it from {@link #getName()}. * and returning it from {@link #getName()}.
* *
* @author Joe Grandja * @author Joe Grandja
* @since 5.0
* @see OAuth2User * @see OAuth2User
* @since 5.0
*/ */
public class DefaultOAuth2User implements OAuth2User, Serializable { public class DefaultOAuth2User implements OAuth2User, Serializable {
private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
@ -55,8 +54,8 @@ public class DefaultOAuth2User implements OAuth2User, Serializable {
/** /**
* Constructs a {@code DefaultOAuth2User} using the provided parameters. * Constructs a {@code DefaultOAuth2User} using the provided parameters.
* *
* @param authorities the authorities granted to the user * @param authorities the authorities granted to the user
* @param attributes the attributes about the user * @param attributes the attributes about the user
* @param nameAttributeKey the key used to access the user's &quot;name&quot; from {@link #getAttributes()} * @param nameAttributeKey the key used to access the user's &quot;name&quot; from {@link #getAttributes()}
*/ */
public DefaultOAuth2User(Collection<? extends GrantedAuthority> authorities, Map<String, Object> attributes, String nameAttributeKey) { public DefaultOAuth2User(Collection<? extends GrantedAuthority> authorities, Map<String, Object> attributes, String nameAttributeKey) {
@ -88,7 +87,7 @@ public class DefaultOAuth2User implements OAuth2User, Serializable {
private Set<GrantedAuthority> sortAuthorities(Collection<? extends GrantedAuthority> authorities) { private Set<GrantedAuthority> sortAuthorities(Collection<? extends GrantedAuthority> authorities) {
SortedSet<GrantedAuthority> sortedAuthorities = SortedSet<GrantedAuthority> sortedAuthorities =
new TreeSet<>(Comparator.comparing(GrantedAuthority::getAuthority)); new TreeSet<>(Comparator.comparing(GrantedAuthority::getAuthority));
sortedAuthorities.addAll(authorities); sortedAuthorities.addAll(authorities);
return sortedAuthorities; return sortedAuthorities;
} }
@ -127,9 +126,9 @@ public class DefaultOAuth2User implements OAuth2User, Serializable {
sb.append("Name: ["); sb.append("Name: [");
sb.append(this.getName()); sb.append(this.getName());
sb.append("], Granted Authorities: ["); sb.append("], Granted Authorities: [");
sb.append(this.getAuthorities().stream().map(GrantedAuthority::getAuthority).collect(Collectors.joining(", "))); sb.append(getAuthorities());
sb.append("], User Attributes: ["); sb.append("], User Attributes: [");
sb.append(this.getAttributes().entrySet().stream().map(e -> e.getKey() + "=" + e.getValue()).collect(Collectors.joining(", "))); sb.append(getAttributes());
sb.append("]"); sb.append("]");
return sb.toString(); return sb.toString();
} }

View File

@ -15,7 +15,6 @@
*/ */
package org.springframework.security.oauth2.jose.jws; package org.springframework.security.oauth2.jose.jws;
import java.util.stream.Stream;
/** /**
* An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification * An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification
@ -59,10 +58,12 @@ public enum MacAlgorithm implements JwsAlgorithm {
* @return the resolved {@code MacAlgorithm}, or {@code null} if not found * @return the resolved {@code MacAlgorithm}, or {@code null} if not found
*/ */
public static MacAlgorithm from(String name) { public static MacAlgorithm from(String name) {
return Stream.of(values()) for (MacAlgorithm algorithm : values()) {
.filter(algorithm -> algorithm.getName().equals(name)) if (algorithm.getName().equals(name)) {
.findFirst() return algorithm;
.orElse(null); }
}
return null;
} }
/** /**

View File

@ -15,8 +15,6 @@
*/ */
package org.springframework.security.oauth2.jose.jws; package org.springframework.security.oauth2.jose.jws;
import java.util.stream.Stream;
/** /**
* An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification * An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification
* and used by JSON Web Signature (JWS) to digitally sign the contents of the JWS Protected Header and JWS Payload. * and used by JSON Web Signature (JWS) to digitally sign the contents of the JWS Protected Header and JWS Payload.
@ -89,10 +87,12 @@ public enum SignatureAlgorithm implements JwsAlgorithm {
* @return the resolved {@code SignatureAlgorithm}, or {@code null} if not found * @return the resolved {@code SignatureAlgorithm}, or {@code null} if not found
*/ */
public static SignatureAlgorithm from(String name) { public static SignatureAlgorithm from(String name) {
return Stream.of(values()) for (SignatureAlgorithm value : values()) {
.filter(algorithm -> algorithm.getName().equals(name)) if (value.getName().equals(name)) {
.findFirst() return value;
.orElse(null); }
}
return null;
} }
/** /**

View File

@ -29,7 +29,6 @@ import java.time.Instant;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
/** /**
* Converts a JWT claim set, claim by claim. Can be configured with custom converters * Converts a JWT claim set, claim by claim. Can be configured with custom converters
@ -161,17 +160,22 @@ public final class MappedJwtClaimSetConverter implements Converter<Map<String, O
} }
private Map<String, Object> removeClaims(Map<String, Object> claims) { private Map<String, Object> removeClaims(Map<String, Object> claims) {
return claims.entrySet().stream() Map<String, Object> result = new HashMap<>();
.filter(e -> e.getValue() != null) for (Map.Entry<String, Object> entry : claims.entrySet()) {
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); if (entry.getValue() != null) {
result.put(entry.getKey(), entry.getValue());
}
}
return result;
} }
private Map<String, Object> addClaims(Map<String, Object> claims) { private Map<String, Object> addClaims(Map<String, Object> claims) {
Map<String, Object> result = new HashMap<>(claims); Map<String, Object> result = new HashMap<>(claims);
this.claimTypeConverters.entrySet().stream() for (Map.Entry<String, Converter<Object, ?>> entry : claimTypeConverters.entrySet()) {
.filter(e -> !claims.containsKey(e.getKey())) if (!claims.containsKey(entry.getKey()) && entry.getValue().convert(null) != null) {
.filter(e -> e.getValue().convert(null) != null) result.put(entry.getKey(), entry.getValue().convert(null));
.forEach(e -> result.put(e.getKey(), e.getValue().convert(null))); }
}
return result; return result;
} }
} }

View File

@ -22,7 +22,7 @@ import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.ArrayList;
import com.nimbusds.oauth2.sdk.TokenIntrospectionResponse; import com.nimbusds.oauth2.sdk.TokenIntrospectionResponse;
import com.nimbusds.oauth2.sdk.TokenIntrospectionSuccessResponse; import com.nimbusds.oauth2.sdk.TokenIntrospectionSuccessResponse;
@ -192,9 +192,11 @@ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector {
private Map<String, Object> convertClaimsSet(TokenIntrospectionSuccessResponse response) { private Map<String, Object> convertClaimsSet(TokenIntrospectionSuccessResponse response) {
Map<String, Object> claims = response.toJSONObject(); Map<String, Object> claims = response.toJSONObject();
if (response.getAudience() != null) { if (response.getAudience() != null) {
List<String> audience = response.getAudience().stream() List<String> audiences = new ArrayList<>();
.map(Audience::getValue).collect(Collectors.toList()); for (Audience audience : response.getAudience()) {
claims.put(AUDIENCE, Collections.unmodifiableList(audience)); audiences.add(audience.getValue());
}
claims.put(AUDIENCE, Collections.unmodifiableList(audiences));
} }
if (response.getClientID() != null) { if (response.getClientID() != null) {
claims.put(CLIENT_ID, response.getClientID().getValue()); claims.put(CLIENT_ID, response.getClientID().getValue());

View File

@ -22,7 +22,7 @@ import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.ArrayList;
import com.nimbusds.oauth2.sdk.TokenIntrospectionResponse; import com.nimbusds.oauth2.sdk.TokenIntrospectionResponse;
import com.nimbusds.oauth2.sdk.TokenIntrospectionSuccessResponse; import com.nimbusds.oauth2.sdk.TokenIntrospectionSuccessResponse;
@ -153,9 +153,11 @@ public class NimbusReactiveOpaqueTokenIntrospector implements ReactiveOpaqueToke
private Map<String, Object> convertClaimsSet(TokenIntrospectionSuccessResponse response) { private Map<String, Object> convertClaimsSet(TokenIntrospectionSuccessResponse response) {
Map<String, Object> claims = response.toJSONObject(); Map<String, Object> claims = response.toJSONObject();
if (response.getAudience() != null) { if (response.getAudience() != null) {
List<String> audience = response.getAudience().stream() List<String> audiences = new ArrayList<>();
.map(Audience::getValue).collect(Collectors.toList()); for (Audience audience : response.getAudience()) {
claims.put(AUDIENCE, Collections.unmodifiableList(audience)); audiences.add(audience.getValue());
}
claims.put(AUDIENCE, Collections.unmodifiableList(audiences));
} }
if (response.getClientID() != null) { if (response.getClientID() != null) {
claims.put(CLIENT_ID, response.getClientID().getValue()); claims.put(CLIENT_ID, response.getClientID().getValue());

View File

@ -19,7 +19,6 @@ package org.springframework.security.oauth2.server.resource.web;
import java.io.IOException; import java.io.IOException;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -41,10 +40,10 @@ import org.springframework.util.StringUtils;
* {@code WWW-Authenticate} HTTP header. * {@code WWW-Authenticate} HTTP header.
* *
* @author Vedran Pavic * @author Vedran Pavic
* @since 5.1
* @see BearerTokenError * @see BearerTokenError
* @see <a href="https://tools.ietf.org/html/rfc6750#section-3" target="_blank">RFC 6750 Section 3: The WWW-Authenticate * @see <a href="https://tools.ietf.org/html/rfc6750#section-3" target="_blank">RFC 6750 Section 3: The WWW-Authenticate
* Response Header Field</a> * Response Header Field</a>
* @since 5.1
*/ */
public final class BearerTokenAuthenticationEntryPoint implements AuthenticationEntryPoint { public final class BearerTokenAuthenticationEntryPoint implements AuthenticationEntryPoint {
@ -54,8 +53,8 @@ public final class BearerTokenAuthenticationEntryPoint implements Authentication
* Collect error details from the provided parameters and format according to * Collect error details from the provided parameters and format according to
* RFC 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and {@code scope}. * RFC 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and {@code scope}.
* *
* @param request that resulted in an <code>AuthenticationException</code> * @param request that resulted in an <code>AuthenticationException</code>
* @param response so that the user agent can begin authentication * @param response so that the user agent can begin authentication
* @param authException that caused the invocation * @param authException that caused the invocation
*/ */
@Override @Override
@ -112,13 +111,22 @@ public final class BearerTokenAuthenticationEntryPoint implements Authentication
} }
private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) { private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) {
String wwwAuthenticate = "Bearer"; StringBuilder wwwAuthenticate = new StringBuilder();
wwwAuthenticate.append("Bearer");
if (!parameters.isEmpty()) { if (!parameters.isEmpty()) {
wwwAuthenticate += parameters.entrySet().stream() wwwAuthenticate.append(" ");
.map(attribute -> attribute.getKey() + "=\"" + attribute.getValue() + "\"") int i = 0;
.collect(Collectors.joining(", ", " ", "")); for (Map.Entry<String, String> entry : parameters.entrySet()) {
wwwAuthenticate.append(entry.getKey()).append("=\"").append(entry.getValue()).append("\"");
if (i != parameters.size() - 1) {
wwwAuthenticate.append(", ");
}
i++;
}
} }
return wwwAuthenticate; return wwwAuthenticate.toString();
} }
} }

View File

@ -30,12 +30,11 @@ import javax.servlet.http.HttpServletResponse;
import java.io.IOException; import java.io.IOException;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
/** /**
* Translates any {@link AccessDeniedException} into an HTTP response in accordance with * Translates any {@link AccessDeniedException} into an HTTP response in accordance with
* <a href="https://tools.ietf.org/html/rfc6750#section-3" target="_blank">RFC 6750 Section 3: The WWW-Authenticate</a>. * <a href="https://tools.ietf.org/html/rfc6750#section-3" target="_blank">RFC 6750 Section 3: The WWW-Authenticate</a>.
* * <p>
* So long as the class can prove that the request has a valid OAuth 2.0 {@link Authentication}, then will return an * So long as the class can prove that the request has a valid OAuth 2.0 {@link Authentication}, then will return an
* <a href="https://tools.ietf.org/html/rfc6750#section-3.1" target="_blank">insufficient scope error</a>; otherwise, * <a href="https://tools.ietf.org/html/rfc6750#section-3.1" target="_blank">insufficient scope error</a>; otherwise,
* it will simply indicate the scheme (Bearer) and any configured realm. * it will simply indicate the scheme (Bearer) and any configured realm.
@ -51,10 +50,9 @@ public final class BearerTokenAccessDeniedHandler implements AccessDeniedHandler
* Collect error details from the provided parameters and format according to * Collect error details from the provided parameters and format according to
* RFC 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and {@code scope}. * RFC 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and {@code scope}.
* *
* @param request that resulted in an <code>AccessDeniedException</code> * @param request that resulted in an <code>AccessDeniedException</code>
* @param response so that the user agent can be advised of the failure * @param response so that the user agent can be advised of the failure
* @param accessDeniedException that caused the invocation * @param accessDeniedException that caused the invocation
*
*/ */
@Override @Override
public void handle( public void handle(
@ -90,13 +88,22 @@ public final class BearerTokenAccessDeniedHandler implements AccessDeniedHandler
} }
private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) { private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) {
String wwwAuthenticate = "Bearer"; StringBuilder wwwAuthenticate = new StringBuilder();
wwwAuthenticate.append("Bearer");
if (!parameters.isEmpty()) { if (!parameters.isEmpty()) {
wwwAuthenticate += parameters.entrySet().stream() wwwAuthenticate.append(" ");
.map(attribute -> attribute.getKey() + "=\"" + attribute.getValue() + "\"") int i = 0;
.collect(Collectors.joining(", ", " ", "")); for (Map.Entry<String, String> entry : parameters.entrySet()) {
wwwAuthenticate.append(entry.getKey()).append("=\"").append(entry.getValue()).append("\"");
if (i != parameters.size() - 1) {
wwwAuthenticate.append(", ");
}
i++;
}
} }
return wwwAuthenticate; return wwwAuthenticate.toString();
} }
} }

View File

@ -30,7 +30,6 @@ import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
/** /**
* Translates any {@link AccessDeniedException} into an HTTP response in accordance with * Translates any {@link AccessDeniedException} into an HTTP response in accordance with
@ -91,13 +90,20 @@ public class BearerTokenServerAccessDeniedHandler implements ServerAccessDeniedH
} }
private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) { private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) {
String wwwAuthenticate = "Bearer"; StringBuilder wwwAuthenticate = new StringBuilder();
wwwAuthenticate.append("Bearer");
if (!parameters.isEmpty()) { if (!parameters.isEmpty()) {
wwwAuthenticate += parameters.entrySet().stream() wwwAuthenticate.append(" ");
.map(attribute -> attribute.getKey() + "=\"" + attribute.getValue() + "\"") int i = 0;
.collect(Collectors.joining(", ", " ", "")); for (Map.Entry<String, String> entry : parameters.entrySet()) {
wwwAuthenticate.append(entry.getKey()).append("=\"").append(entry.getValue()).append("\"");
if (i != parameters.size() - 1) {
wwwAuthenticate.append(", ");
}
i++;
}
} }
return wwwAuthenticate; return wwwAuthenticate.toString();
} }
} }

View File

@ -32,7 +32,6 @@ import reactor.core.publisher.Mono;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
/** /**
* An {@link AuthenticationEntryPoint} implementation used to commence authentication of protected resource requests * An {@link AuthenticationEntryPoint} implementation used to commence authentication of protected resource requests
@ -42,10 +41,10 @@ import java.util.stream.Collectors;
* {@code WWW-Authenticate} HTTP header. * {@code WWW-Authenticate} HTTP header.
* *
* @author Rob Winch * @author Rob Winch
* @since 5.1
* @see BearerTokenError * @see BearerTokenError
* @see <a href="https://tools.ietf.org/html/rfc6750#section-3" target="_blank">RFC 6750 Section 3: The WWW-Authenticate * @see <a href="https://tools.ietf.org/html/rfc6750#section-3" target="_blank">RFC 6750 Section 3: The WWW-Authenticate
* Response Header Field</a> * Response Header Field</a>
* @since 5.1
*/ */
public final class BearerTokenServerAuthenticationEntryPoint implements public final class BearerTokenServerAuthenticationEntryPoint implements
ServerAuthenticationEntryPoint { ServerAuthenticationEntryPoint {
@ -111,13 +110,21 @@ public final class BearerTokenServerAuthenticationEntryPoint implements
} }
private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) { private static String computeWWWAuthenticateHeaderValue(Map<String, String> parameters) {
String wwwAuthenticate = "Bearer"; StringBuilder wwwAuthenticate = new StringBuilder();
wwwAuthenticate.append("Bearer");
if (!parameters.isEmpty()) { if (!parameters.isEmpty()) {
wwwAuthenticate += parameters.entrySet().stream() wwwAuthenticate.append(" ");
.map(attribute -> attribute.getKey() + "=\"" + attribute.getValue() + "\"") int i = 0;
.collect(Collectors.joining(", ", " ", "")); for (Map.Entry<String, String> entry : parameters.entrySet()) {
wwwAuthenticate.append(entry.getKey()).append("=\"").append(entry.getValue()).append("\"");
if (i != parameters.size() - 1) {
wwwAuthenticate.append(", ");
}
i++;
}
} }
return wwwAuthenticate; return wwwAuthenticate.toString();
} }
} }

View File

@ -16,9 +16,6 @@
package org.springframework.security.web.header.writers; package org.springframework.security.web.header.writers;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -69,7 +66,7 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter {
public ClearSiteDataHeaderWriter(String ...sources) { public ClearSiteDataHeaderWriter(String ...sources) {
Assert.notEmpty(sources, "sources cannot be empty or null"); Assert.notEmpty(sources, "sources cannot be empty or null");
this.requestMatcher = new SecureRequestMatcher(); this.requestMatcher = new SecureRequestMatcher();
this.headerValue = Stream.of(sources).map(this::quote).collect(Collectors.joining(", ")); this.headerValue = joinQuotes(sources);
} }
@Override @Override
@ -84,6 +81,15 @@ public final class ClearSiteDataHeaderWriter implements HeaderWriter {
} }
} }
private String joinQuotes(String ...sources) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < sources.length-1; i++) {
sb.append(quote(sources[i])).append(", ");
}
sb.append(quote(sources[sources.length-1]));
return sb.toString();
}
private static final class SecureRequestMatcher implements RequestMatcher { private static final class SecureRequestMatcher implements RequestMatcher {
public boolean matches(HttpServletRequest request) { public boolean matches(HttpServletRequest request) {
return request.isSecure(); return request.isSecure();

View File

@ -23,8 +23,7 @@ import reactor.core.publisher.Mono;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.ArrayList;
import java.util.stream.Stream;
/** /**
* Delegates to a collection of {@link ServerAuthenticationSuccessHandler} implementations. * Delegates to a collection of {@link ServerAuthenticationSuccessHandler} implementations.
@ -43,7 +42,10 @@ public class DelegatingServerAuthenticationSuccessHandler implements ServerAuthe
@Override @Override
public Mono<Void> onAuthenticationSuccess(WebFilterExchange exchange, public Mono<Void> onAuthenticationSuccess(WebFilterExchange exchange,
Authentication authentication) { Authentication authentication) {
Stream<Mono<Void>> results = this.delegates.stream().map(delegate -> delegate.onAuthenticationSuccess(exchange, authentication)); List<Mono<Void>> results = new ArrayList<>();
return Mono.when(results.collect(Collectors.toList())); for (ServerAuthenticationSuccessHandler delegate : delegates) {
results.add(delegate.onAuthenticationSuccess(exchange, authentication));
}
return Mono.when(results);
} }
} }

View File

@ -20,8 +20,6 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -50,10 +48,12 @@ public class DelegatingServerLogoutHandler implements ServerLogoutHandler {
@Override @Override
public Mono<Void> logout(WebFilterExchange exchange, Authentication authentication) { public Mono<Void> logout(WebFilterExchange exchange, Authentication authentication) {
return Mono.when(this.delegates.stream() List<Mono<Void>> results = new ArrayList<>();
.filter(Objects::nonNull) for (ServerLogoutHandler delegate : delegates) {
.map(delegate -> delegate.logout(exchange, authentication)) if (delegate != null) {
.collect(Collectors.toList()) results.add(delegate.logout(exchange, authentication));
); }
}
return Mono.when(results);
} }
} }

View File

@ -20,9 +20,6 @@ import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
* <p>Writes the {@code Clear-Site-Data} response header when the request is secure.</p> * <p>Writes the {@code Clear-Site-Data} response header when the request is secure.</p>
* *
@ -81,9 +78,12 @@ public final class ClearSiteDataServerHttpHeadersWriter implements ServerHttpHea
} }
private String transformToHeaderValue(Directive... directives) { private String transformToHeaderValue(Directive... directives) {
return Stream.of(directives) StringBuilder sb = new StringBuilder();
.map(Directive::getHeaderValue) for (int i = 0; i < directives.length - 1; i++) {
.collect(Collectors.joining(", ")); sb.append(directives[i].headerValue).append(", ");
}
sb.append(directives[directives.length - 1].headerValue);
return sb.toString();
} }
private boolean isSecure(ServerWebExchange exchange) { private boolean isSecure(ServerWebExchange exchange) {

View File

@ -17,8 +17,7 @@ package org.springframework.security.web.server.header;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.ArrayList;
import java.util.stream.Stream;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
@ -43,8 +42,10 @@ public class CompositeServerHttpHeadersWriter implements ServerHttpHeadersWriter
@Override @Override
public Mono<Void> writeHttpHeaders(ServerWebExchange exchange) { public Mono<Void> writeHttpHeaders(ServerWebExchange exchange) {
Stream<Mono<Void>> results = writers.stream().map( writer -> writer.writeHttpHeaders(exchange)); List<Mono<Void>> results = new ArrayList<>();
return Mono.when(results.collect(Collectors.toList())); for (ServerHttpHeadersWriter writer : writers) {
results.add(writer.writeHttpHeaders(exchange));
}
return Mono.when(results);
} }
} }