Polish Memory Leak Mitigation

Issue gh-9841
This commit is contained in:
Josh Cummings 2021-11-30 12:32:36 -07:00
parent 809ff883b0
commit 78857c62f4
6 changed files with 142 additions and 70 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,10 +16,14 @@
package org.springframework.security.config.annotation.web.configuration;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
@ -36,7 +40,6 @@ import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
@ -68,17 +71,22 @@ class SecurityReactorContextConfiguration {
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
private static final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
static {
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
SecurityReactorContextSubscriberRegistrar::getRequest);
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
SecurityReactorContextSubscriberRegistrar::getResponse);
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class,
SecurityReactorContextSubscriberRegistrar::getAuthentication);
}
@Override
public void afterPropertiesSet() throws Exception {
Function<? super Publisher<Object>, ? extends Publisher<Object>> lifter = Operators
.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub));
Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, (pub) -> {
if (!contextAttributesAvailable()) {
// No need to decorate so return original Publisher
return pub;
}
return lifter.apply(pub);
});
Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, lifter::apply);
}
@Override
@ -94,45 +102,30 @@ class SecurityReactorContextConfiguration {
return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
}
private static boolean contextAttributesAvailable() {
SecurityContext context = SecurityContextHolder.peekContext();
Authentication authentication = null;
if (context != null) {
authentication = context.getAuthentication();
}
return authentication != null
|| RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
private static Map<Object, Object> getContextAttributes() {
return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS);
}
private static Map<Object, Object> getContextAttributes() {
HttpServletRequest servletRequest = null;
HttpServletResponse servletResponse = null;
private static HttpServletRequest getRequest() {
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
if (requestAttributes instanceof ServletRequestAttributes) {
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
servletRequest = servletRequestAttributes.getRequest();
servletResponse = servletRequestAttributes.getResponse(); // possible null
}
SecurityContext context = SecurityContextHolder.peekContext();
Authentication authentication = null;
if (context != null) {
authentication = context.getAuthentication();
}
if (authentication == null && servletRequest == null) {
return Collections.emptyMap();
}
Map<Object, Object> contextAttributes = new HashMap<>();
if (servletRequest != null) {
contextAttributes.put(HttpServletRequest.class, servletRequest);
}
if (servletResponse != null) {
contextAttributes.put(HttpServletResponse.class, servletResponse);
}
if (authentication != null) {
contextAttributes.put(Authentication.class, authentication);
return servletRequestAttributes.getRequest();
}
return null;
}
return contextAttributes;
private static HttpServletResponse getResponse() {
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
if (requestAttributes instanceof ServletRequestAttributes) {
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
return servletRequestAttributes.getResponse(); // possible null
}
return null;
}
private static Authentication getAuthentication() {
return SecurityContextHolder.getContext().getAuthentication();
}
}
@ -185,4 +178,112 @@ class SecurityReactorContextConfiguration {
}
/**
* A map that computes each value when {@link #get} is invoked
*/
static class LoadingMap<K, V> implements Map<K, V> {
private final Map<K, V> loaded = new ConcurrentHashMap<>();
private final Map<K, Supplier<V>> loaders;
LoadingMap(Map<K, Supplier<V>> loaders) {
this.loaders = Collections.unmodifiableMap(new HashMap<>(loaders));
}
@Override
public int size() {
return this.loaders.size();
}
@Override
public boolean isEmpty() {
return this.loaders.isEmpty();
}
@Override
public boolean containsKey(Object key) {
return this.loaders.containsKey(key);
}
@Override
public Set<K> keySet() {
return this.loaders.keySet();
}
@Override
public V get(Object key) {
if (!this.loaders.containsKey(key)) {
throw new IllegalArgumentException(
"This map only supports the following keys: " + this.loaders.keySet());
}
return this.loaded.computeIfAbsent((K) key, (k) -> this.loaders.get(k).get());
}
@Override
public V put(K key, V value) {
if (!this.loaders.containsKey(key)) {
throw new IllegalArgumentException(
"This map only supports the following keys: " + this.loaders.keySet());
}
return this.loaded.put(key, value);
}
@Override
public V remove(Object key) {
if (!this.loaders.containsKey(key)) {
throw new IllegalArgumentException(
"This map only supports the following keys: " + this.loaders.keySet());
}
return this.loaded.remove(key);
}
@Override
public void putAll(Map<? extends K, ? extends V> m) {
for (Map.Entry<? extends K, ? extends V> entry : m.entrySet()) {
put(entry.getKey(), entry.getValue());
}
}
@Override
public void clear() {
this.loaded.clear();
}
@Override
public boolean containsValue(Object value) {
return this.loaded.containsValue(value);
}
@Override
public Collection<V> values() {
return this.loaded.values();
}
@Override
public Set<Entry<K, V>> entrySet() {
return this.loaded.entrySet();
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
LoadingMap<?, ?> that = (LoadingMap<?, ?>) o;
return this.loaded.equals(that.loaded);
}
@Override
public int hashCode() {
return this.loaded.hashCode();
}
}
}

View File

@ -44,11 +44,6 @@ final class GlobalSecurityContextHolderStrategy implements SecurityContextHolder
return contextHolder;
}
@Override
public SecurityContext peekContext() {
return contextHolder;
}
@Override
public void setContext(SecurityContext context) {
Assert.notNull(context, "Only non-null SecurityContext instances are permitted");

View File

@ -44,11 +44,6 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur
return ctx;
}
@Override
public SecurityContext peekContext() {
return contextHolder.get();
}
@Override
public void setContext(SecurityContext context) {
Assert.notNull(context, "Only non-null SecurityContext instances are permitted");

View File

@ -123,14 +123,6 @@ public class SecurityContextHolder {
return strategy.getContext();
}
/**
* Peeks the current <code>SecurityContext</code>.
* @return the security context (may be <code>null</code>)
*/
public static SecurityContext peekContext() {
return strategy.peekContext();
}
/**
* Primarily for troubleshooting purposes, this method shows how many times the class
* has re-initialized its <code>SecurityContextHolderStrategy</code>.

View File

@ -38,12 +38,6 @@ public interface SecurityContextHolderStrategy {
*/
SecurityContext getContext();
/**
* Peeks the current context without creating an empty context.
* @return a context (may be <code>null</code>)
*/
SecurityContext peekContext();
/**
* Sets the current context.
* @param context to the new argument (should never be <code>null</code>, although

View File

@ -45,11 +45,6 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH
return ctx;
}
@Override
public SecurityContext peekContext() {
return contextHolder.get();
}
@Override
public void setContext(SecurityContext context) {
Assert.notNull(context, "Only non-null SecurityContext instances are permitted");