From 0012e24c46872e2159beeee41b0f024beeee9e7f Mon Sep 17 00:00:00 2001 From: Stephane Maldini Date: Thu, 6 Feb 2020 14:16:38 -0800 Subject: [PATCH] Don't force downcasting of RequestAttributes to ServletRequestAttributes Fixes gh-7953 --- .../SecurityReactorContextConfiguration.java | 30 ++++------- ...urityReactorContextConfigurationTests.java | 50 ++++++++++++++++++- .../DefaultOAuth2AuthorizedClientManager.java | 15 +++--- ...uthorizedClientExchangeFilterFunction.java | 15 +++--- 4 files changed, 73 insertions(+), 37 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java index a4e8fe0df6..8d76982c80 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import reactor.core.CoreSubscriber; @@ -92,32 +93,21 @@ class SecurityReactorContextConfiguration { } private static boolean contextAttributesAvailable() { - HttpServletRequest servletRequest = null; - HttpServletResponse servletResponse = null; - ServletRequestAttributes requestAttributes = - (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); - if (requestAttributes != null) { - servletRequest = requestAttributes.getRequest(); - servletResponse = requestAttributes.getResponse(); - } - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (authentication != null || servletRequest != null || servletResponse != null) { - return true; - } - return false; + return SecurityContextHolder.getContext().getAuthentication() != null || + RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes; } private static Map getContextAttributes() { HttpServletRequest servletRequest = null; HttpServletResponse servletResponse = null; - ServletRequestAttributes requestAttributes = - (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); - if (requestAttributes != null) { - servletRequest = requestAttributes.getRequest(); - servletResponse = requestAttributes.getResponse(); + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + if (requestAttributes instanceof ServletRequestAttributes) { + ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes; + servletRequest = servletRequestAttributes.getRequest(); + servletResponse = servletRequestAttributes.getResponse(); // possible null } Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (authentication == null && servletRequest == null && servletResponse == null) { + if (authentication == null && servletRequest == null) { return Collections.emptyMap(); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java index b26cef84b7..11d8317dde 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction; +import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.client.ClientRequest; @@ -36,6 +37,7 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import reactor.core.CoreSubscriber; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; import reactor.test.StepVerifier; import reactor.util.context.Context; @@ -139,6 +141,52 @@ public class SecurityReactorContextConfigurationTests { assertThat(resultContext).isSameAs(parentContext); } + @Test + public void createSubscriberIfNecessaryWhenNotServletRequestAttributesThenStillCreate() { + RequestContextHolder.setRequestAttributes( + new RequestAttributes() { + @Override + public Object getAttribute(String name, int scope) { + return null; + } + + @Override + public void setAttribute(String name, Object value, int scope) { + } + + @Override + public void removeAttribute(String name, int scope) { + } + + @Override + public String[] getAttributeNames(int scope) { + return new String[0]; + } + + @Override + public void registerDestructionCallback(String name, Runnable callback, int scope) { + } + + @Override + public Object resolveReference(String key) { + return null; + } + + @Override + public String getSessionId() { + return null; + } + + @Override + public Object getSessionMutex() { + return null; + } + }); + + CoreSubscriber subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(Operators.emptySubscriber()); + assertThat(subscriber).isInstanceOf(SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.class); + } + @Test public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() { // Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 19719dc7c4..85b614c070 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -121,9 +122,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori private static HttpServletRequest getHttpServletRequestOrDefault(Map attributes) { HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()); if (servletRequest == null) { - ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); - if (context != null) { - servletRequest = context.getRequest(); + RequestAttributes context = RequestContextHolder.getRequestAttributes(); + if (context instanceof ServletRequestAttributes) { + servletRequest = ((ServletRequestAttributes) context).getRequest(); } } return servletRequest; @@ -132,9 +133,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori private static HttpServletResponse getHttpServletResponseOrDefault(Map attributes) { HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()); if (servletResponse == null) { - ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); - if (context != null) { - servletResponse = context.getResponse(); + RequestAttributes context = RequestContextHolder.getRequestAttributes(); + if (context instanceof ServletRequestAttributes) { + servletResponse = ((ServletRequestAttributes) context).getResponse(); } } return servletResponse; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 147e3255c0..22d488c965 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; +import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.client.ClientRequest; @@ -389,15 +390,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { return; } - ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); - HttpServletRequest request = null; - HttpServletResponse response = null; - if (context != null) { - request = context.getRequest(); - response = context.getResponse(); + RequestAttributes context = RequestContextHolder.getRequestAttributes(); + if (context instanceof ServletRequestAttributes) { + attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ((ServletRequestAttributes) context).getRequest()); + attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ((ServletRequestAttributes) context).getResponse()); } - attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request); - attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response); } private void populateDefaultAuthentication(Map attrs) {