diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java index 4210e90fe6..41ab29c93a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,14 @@ package org.springframework.security.config.annotation.web.configuration; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import java.util.function.Supplier; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -36,7 +40,6 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; @@ -68,17 +71,22 @@ class SecurityReactorContextConfiguration { private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR"; + private static final Map> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>(); + + static { + CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class, + SecurityReactorContextSubscriberRegistrar::getRequest); + CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class, + SecurityReactorContextSubscriberRegistrar::getResponse); + CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class, + SecurityReactorContextSubscriberRegistrar::getAuthentication); + } + @Override public void afterPropertiesSet() throws Exception { Function, ? extends Publisher> lifter = Operators .liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub)); - Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, (pub) -> { - if (!contextAttributesAvailable()) { - // No need to decorate so return original Publisher - return pub; - } - return lifter.apply(pub); - }); + Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, lifter::apply); } @Override @@ -94,45 +102,30 @@ class SecurityReactorContextConfiguration { return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes()); } - private static boolean contextAttributesAvailable() { - SecurityContext context = SecurityContextHolder.peekContext(); - Authentication authentication = null; - if (context != null) { - authentication = context.getAuthentication(); - } - return authentication != null - || RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes; + private static Map getContextAttributes() { + return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS); } - private static Map getContextAttributes() { - HttpServletRequest servletRequest = null; - HttpServletResponse servletResponse = null; + private static HttpServletRequest getRequest() { RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); if (requestAttributes instanceof ServletRequestAttributes) { ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes; - servletRequest = servletRequestAttributes.getRequest(); - servletResponse = servletRequestAttributes.getResponse(); // possible null - } - SecurityContext context = SecurityContextHolder.peekContext(); - Authentication authentication = null; - if (context != null) { - authentication = context.getAuthentication(); - } - if (authentication == null && servletRequest == null) { - return Collections.emptyMap(); - } - Map contextAttributes = new HashMap<>(); - if (servletRequest != null) { - contextAttributes.put(HttpServletRequest.class, servletRequest); - } - if (servletResponse != null) { - contextAttributes.put(HttpServletResponse.class, servletResponse); - } - if (authentication != null) { - contextAttributes.put(Authentication.class, authentication); + return servletRequestAttributes.getRequest(); } + return null; + } - return contextAttributes; + private static HttpServletResponse getResponse() { + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + if (requestAttributes instanceof ServletRequestAttributes) { + ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes; + return servletRequestAttributes.getResponse(); // possible null + } + return null; + } + + private static Authentication getAuthentication() { + return SecurityContextHolder.getContext().getAuthentication(); } } @@ -185,4 +178,112 @@ class SecurityReactorContextConfiguration { } + /** + * A map that computes each value when {@link #get} is invoked + */ + static class LoadingMap implements Map { + + private final Map loaded = new ConcurrentHashMap<>(); + + private final Map> loaders; + + LoadingMap(Map> loaders) { + this.loaders = Collections.unmodifiableMap(new HashMap<>(loaders)); + } + + @Override + public int size() { + return this.loaders.size(); + } + + @Override + public boolean isEmpty() { + return this.loaders.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return this.loaders.containsKey(key); + } + + @Override + public Set keySet() { + return this.loaders.keySet(); + } + + @Override + public V get(Object key) { + if (!this.loaders.containsKey(key)) { + throw new IllegalArgumentException( + "This map only supports the following keys: " + this.loaders.keySet()); + } + return this.loaded.computeIfAbsent((K) key, (k) -> this.loaders.get(k).get()); + } + + @Override + public V put(K key, V value) { + if (!this.loaders.containsKey(key)) { + throw new IllegalArgumentException( + "This map only supports the following keys: " + this.loaders.keySet()); + } + return this.loaded.put(key, value); + } + + @Override + public V remove(Object key) { + if (!this.loaders.containsKey(key)) { + throw new IllegalArgumentException( + "This map only supports the following keys: " + this.loaders.keySet()); + } + return this.loaded.remove(key); + } + + @Override + public void putAll(Map m) { + for (Map.Entry entry : m.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public void clear() { + this.loaded.clear(); + } + + @Override + public boolean containsValue(Object value) { + return this.loaded.containsValue(value); + } + + @Override + public Collection values() { + return this.loaded.values(); + } + + @Override + public Set> entrySet() { + return this.loaded.entrySet(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + LoadingMap that = (LoadingMap) o; + + return this.loaded.equals(that.loaded); + } + + @Override + public int hashCode() { + return this.loaded.hashCode(); + } + + } + } diff --git a/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java index 330afdb743..d8367c4ebd 100644 --- a/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java @@ -44,11 +44,6 @@ final class GlobalSecurityContextHolderStrategy implements SecurityContextHolder return contextHolder; } - @Override - public SecurityContext peekContext() { - return contextHolder; - } - @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted"); diff --git a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java index d08b221c28..cb415500ca 100644 --- a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java @@ -44,11 +44,6 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur return ctx; } - @Override - public SecurityContext peekContext() { - return contextHolder.get(); - } - @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted"); diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java index 6671e6dd9d..337fde3a57 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java @@ -123,14 +123,6 @@ public class SecurityContextHolder { return strategy.getContext(); } - /** - * Peeks the current SecurityContext. - * @return the security context (may be null) - */ - public static SecurityContext peekContext() { - return strategy.peekContext(); - } - /** * Primarily for troubleshooting purposes, this method shows how many times the class * has re-initialized its SecurityContextHolderStrategy. diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java index 2a29566fae..4954db70aa 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java @@ -38,12 +38,6 @@ public interface SecurityContextHolderStrategy { */ SecurityContext getContext(); - /** - * Peeks the current context without creating an empty context. - * @return a context (may be null) - */ - SecurityContext peekContext(); - /** * Sets the current context. * @param context to the new argument (should never be null, although diff --git a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java index 84f23bbe22..801f5c8207 100644 --- a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java @@ -45,11 +45,6 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH return ctx; } - @Override - public SecurityContext peekContext() { - return contextHolder.get(); - } - @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted");