|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/*
|
|
|
|
|
* Copyright 2002-2020 the original author or authors.
|
|
|
|
|
* Copyright 2002-2021 the original author or authors.
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -16,10 +16,14 @@
|
|
|
|
|
|
|
|
|
|
package org.springframework.security.config.annotation.web.configuration;
|
|
|
|
|
|
|
|
|
|
import java.util.Collection;
|
|
|
|
|
import java.util.Collections;
|
|
|
|
|
import java.util.HashMap;
|
|
|
|
|
import java.util.Map;
|
|
|
|
|
import java.util.Set;
|
|
|
|
|
import java.util.concurrent.ConcurrentHashMap;
|
|
|
|
|
import java.util.function.Function;
|
|
|
|
|
import java.util.function.Supplier;
|
|
|
|
|
|
|
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
|
|
|
import jakarta.servlet.http.HttpServletResponse;
|
|
|
|
@ -36,7 +40,6 @@ import org.springframework.beans.factory.InitializingBean;
|
|
|
|
|
import org.springframework.context.annotation.Bean;
|
|
|
|
|
import org.springframework.context.annotation.Configuration;
|
|
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
|
|
import org.springframework.security.core.context.SecurityContext;
|
|
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
|
|
import org.springframework.web.context.request.RequestAttributes;
|
|
|
|
|
import org.springframework.web.context.request.RequestContextHolder;
|
|
|
|
@ -68,17 +71,22 @@ class SecurityReactorContextConfiguration {
|
|
|
|
|
|
|
|
|
|
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
|
|
|
|
|
|
|
|
|
|
private static final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
|
|
|
|
|
|
|
|
|
|
static {
|
|
|
|
|
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
|
|
|
|
|
SecurityReactorContextSubscriberRegistrar::getRequest);
|
|
|
|
|
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
|
|
|
|
|
SecurityReactorContextSubscriberRegistrar::getResponse);
|
|
|
|
|
CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class,
|
|
|
|
|
SecurityReactorContextSubscriberRegistrar::getAuthentication);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void afterPropertiesSet() throws Exception {
|
|
|
|
|
Function<? super Publisher<Object>, ? extends Publisher<Object>> lifter = Operators
|
|
|
|
|
.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub));
|
|
|
|
|
Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, (pub) -> {
|
|
|
|
|
if (!contextAttributesAvailable()) {
|
|
|
|
|
// No need to decorate so return original Publisher
|
|
|
|
|
return pub;
|
|
|
|
|
}
|
|
|
|
|
return lifter.apply(pub);
|
|
|
|
|
});
|
|
|
|
|
Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, lifter::apply);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
@ -94,45 +102,30 @@ class SecurityReactorContextConfiguration {
|
|
|
|
|
return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private static boolean contextAttributesAvailable() {
|
|
|
|
|
SecurityContext context = SecurityContextHolder.peekContext();
|
|
|
|
|
Authentication authentication = null;
|
|
|
|
|
if (context != null) {
|
|
|
|
|
authentication = context.getAuthentication();
|
|
|
|
|
}
|
|
|
|
|
return authentication != null
|
|
|
|
|
|| RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
|
|
|
|
|
private static Map<Object, Object> getContextAttributes() {
|
|
|
|
|
return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private static Map<Object, Object> getContextAttributes() {
|
|
|
|
|
HttpServletRequest servletRequest = null;
|
|
|
|
|
HttpServletResponse servletResponse = null;
|
|
|
|
|
private static HttpServletRequest getRequest() {
|
|
|
|
|
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
|
|
|
|
|
if (requestAttributes instanceof ServletRequestAttributes) {
|
|
|
|
|
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
|
|
|
|
|
servletRequest = servletRequestAttributes.getRequest();
|
|
|
|
|
servletResponse = servletRequestAttributes.getResponse(); // possible null
|
|
|
|
|
}
|
|
|
|
|
SecurityContext context = SecurityContextHolder.peekContext();
|
|
|
|
|
Authentication authentication = null;
|
|
|
|
|
if (context != null) {
|
|
|
|
|
authentication = context.getAuthentication();
|
|
|
|
|
}
|
|
|
|
|
if (authentication == null && servletRequest == null) {
|
|
|
|
|
return Collections.emptyMap();
|
|
|
|
|
}
|
|
|
|
|
Map<Object, Object> contextAttributes = new HashMap<>();
|
|
|
|
|
if (servletRequest != null) {
|
|
|
|
|
contextAttributes.put(HttpServletRequest.class, servletRequest);
|
|
|
|
|
}
|
|
|
|
|
if (servletResponse != null) {
|
|
|
|
|
contextAttributes.put(HttpServletResponse.class, servletResponse);
|
|
|
|
|
}
|
|
|
|
|
if (authentication != null) {
|
|
|
|
|
contextAttributes.put(Authentication.class, authentication);
|
|
|
|
|
return servletRequestAttributes.getRequest();
|
|
|
|
|
}
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return contextAttributes;
|
|
|
|
|
private static HttpServletResponse getResponse() {
|
|
|
|
|
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
|
|
|
|
|
if (requestAttributes instanceof ServletRequestAttributes) {
|
|
|
|
|
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
|
|
|
|
|
return servletRequestAttributes.getResponse(); // possible null
|
|
|
|
|
}
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private static Authentication getAuthentication() {
|
|
|
|
|
return SecurityContextHolder.getContext().getAuthentication();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
@ -185,4 +178,112 @@ class SecurityReactorContextConfiguration {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* A map that computes each value when {@link #get} is invoked
|
|
|
|
|
*/
|
|
|
|
|
static class LoadingMap<K, V> implements Map<K, V> {
|
|
|
|
|
|
|
|
|
|
private final Map<K, V> loaded = new ConcurrentHashMap<>();
|
|
|
|
|
|
|
|
|
|
private final Map<K, Supplier<V>> loaders;
|
|
|
|
|
|
|
|
|
|
LoadingMap(Map<K, Supplier<V>> loaders) {
|
|
|
|
|
this.loaders = Collections.unmodifiableMap(new HashMap<>(loaders));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public int size() {
|
|
|
|
|
return this.loaders.size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public boolean isEmpty() {
|
|
|
|
|
return this.loaders.isEmpty();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public boolean containsKey(Object key) {
|
|
|
|
|
return this.loaders.containsKey(key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Set<K> keySet() {
|
|
|
|
|
return this.loaders.keySet();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public V get(Object key) {
|
|
|
|
|
if (!this.loaders.containsKey(key)) {
|
|
|
|
|
throw new IllegalArgumentException(
|
|
|
|
|
"This map only supports the following keys: " + this.loaders.keySet());
|
|
|
|
|
}
|
|
|
|
|
return this.loaded.computeIfAbsent((K) key, (k) -> this.loaders.get(k).get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public V put(K key, V value) {
|
|
|
|
|
if (!this.loaders.containsKey(key)) {
|
|
|
|
|
throw new IllegalArgumentException(
|
|
|
|
|
"This map only supports the following keys: " + this.loaders.keySet());
|
|
|
|
|
}
|
|
|
|
|
return this.loaded.put(key, value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public V remove(Object key) {
|
|
|
|
|
if (!this.loaders.containsKey(key)) {
|
|
|
|
|
throw new IllegalArgumentException(
|
|
|
|
|
"This map only supports the following keys: " + this.loaders.keySet());
|
|
|
|
|
}
|
|
|
|
|
return this.loaded.remove(key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void putAll(Map<? extends K, ? extends V> m) {
|
|
|
|
|
for (Map.Entry<? extends K, ? extends V> entry : m.entrySet()) {
|
|
|
|
|
put(entry.getKey(), entry.getValue());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void clear() {
|
|
|
|
|
this.loaded.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public boolean containsValue(Object value) {
|
|
|
|
|
return this.loaded.containsValue(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Collection<V> values() {
|
|
|
|
|
return this.loaded.values();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Set<Entry<K, V>> entrySet() {
|
|
|
|
|
return this.loaded.entrySet();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public boolean equals(Object o) {
|
|
|
|
|
if (this == o) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (o == null || getClass() != o.getClass()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LoadingMap<?, ?> that = (LoadingMap<?, ?>) o;
|
|
|
|
|
|
|
|
|
|
return this.loaded.equals(that.loaded);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public int hashCode() {
|
|
|
|
|
return this.loaded.hashCode();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|