diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/configuration/OidcSecurityConfiguration.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/configuration/OidcSecurityConfiguration.java index ac37255590..2ced302598 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/configuration/OidcSecurityConfiguration.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/configuration/OidcSecurityConfiguration.java @@ -38,6 +38,7 @@ import org.apache.nifi.web.security.oidc.OidcConfigurationException; import org.apache.nifi.web.security.oidc.OidcUrlPath; import org.apache.nifi.web.security.oidc.client.web.AuthorizedClientExpirationCommand; import org.apache.nifi.web.security.oidc.client.web.OidcBearerTokenRefreshFilter; +import org.apache.nifi.web.security.oidc.client.web.StandardOAuth2AuthorizationRequestResolver; import org.apache.nifi.web.security.oidc.client.web.converter.AuthenticationResultConverter; import org.apache.nifi.web.security.oidc.client.web.converter.AuthorizedClientConverter; import org.apache.nifi.web.security.oidc.client.web.StandardAuthorizationRequestRepository; @@ -187,7 +188,8 @@ public class OidcSecurityConfiguration { */ @Bean public OAuth2AuthorizationRequestRedirectFilter oAuth2AuthorizationRequestRedirectFilter() { - final OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(clientRegistrationRepository()); + final StandardOAuth2AuthorizationRequestResolver authorizationRequestResolver = new StandardOAuth2AuthorizationRequestResolver(clientRegistrationRepository()); + final OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(authorizationRequestResolver); filter.setAuthorizationRequestRepository(authorizationRequestRepository()); filter.setRequestCache(nullRequestCache); return filter; diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/oidc/client/web/StandardOAuth2AuthorizationRequestResolver.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/oidc/client/web/StandardOAuth2AuthorizationRequestResolver.java new file mode 100644 index 0000000000..dd8ca86ee3 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/oidc/client/web/StandardOAuth2AuthorizationRequestResolver.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security.oidc.client.web; + +import org.apache.nifi.web.util.RequestUriBuilder; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.web.util.UriComponentsBuilder; + +import javax.servlet.http.HttpServletRequest; +import java.net.URI; +import java.util.Objects; + +/** + * Authorization Request Resolver supports handling of headers from reverse proxy servers + */ +public class StandardOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver { + private final OAuth2AuthorizationRequestResolver resolver; + + /** + * Resolver constructor delegates to the Spring Security Default Resolver and uses the default Request Base URI + * + * @param clientRegistrationRepository Client Registration Repository + */ + public StandardOAuth2AuthorizationRequestResolver(final ClientRegistrationRepository clientRegistrationRepository) { + Objects.requireNonNull(clientRegistrationRepository, "Repository required"); + resolver = new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); + } + + /** + * Resolve Authorization Request delegating to default resolver + * + * @param request HTTP Servlet Request + * @return OAuth2 Authorization Request or null when not resolved + */ + @Override + public OAuth2AuthorizationRequest resolve(final HttpServletRequest request) { + final OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request); + return getResolvedAuthorizationRequest(authorizationRequest, request); + } + + /** + * Resolve Authorization Request delegating to default resolver + * + * @param request HTTP Servlet Request + * @param clientRegistrationId Client Registration Identifier + * @return OAuth2 Authorization Request or null when not resolved + */ + @Override + public OAuth2AuthorizationRequest resolve(final HttpServletRequest request, final String clientRegistrationId) { + final OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request, clientRegistrationId); + return getResolvedAuthorizationRequest(authorizationRequest, request); + } + + private OAuth2AuthorizationRequest getResolvedAuthorizationRequest(final OAuth2AuthorizationRequest authorizationRequest, final HttpServletRequest request) { + final OAuth2AuthorizationRequest resolvedAuthorizationRequest; + + if (authorizationRequest == null) { + resolvedAuthorizationRequest = null; + } else { + final String redirectUri = authorizationRequest.getRedirectUri(); + if (redirectUri == null) { + resolvedAuthorizationRequest = authorizationRequest; + } else { + final String requestBasedRedirectUri = getRequestBasedRedirectUri(redirectUri, request); + resolvedAuthorizationRequest = OAuth2AuthorizationRequest.from(authorizationRequest).redirectUri(requestBasedRedirectUri).build(); + } + } + + return resolvedAuthorizationRequest; + } + + private String getRequestBasedRedirectUri(final String redirectUri, final HttpServletRequest request) { + final String redirectUriPath = UriComponentsBuilder.fromUriString(redirectUri).build().getPath(); + final URI baseUri = RequestUriBuilder.fromHttpServletRequest(request).path(redirectUriPath).build(); + return UriComponentsBuilder.fromUriString(redirectUri) + .scheme(baseUri.getScheme()) + .host(baseUri.getHost()) + .port(baseUri.getPort()) + .replacePath(baseUri.getPath()) + .build() + .toUriString(); + } +} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/oidc/client/web/StandardOAuth2AuthorizationRequestResolverTest.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/oidc/client/web/StandardOAuth2AuthorizationRequestResolverTest.java new file mode 100644 index 0000000000..5644c68f66 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/oidc/client/web/StandardOAuth2AuthorizationRequestResolverTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security.oidc.client.web; + +import org.apache.nifi.web.util.WebUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; + +import javax.servlet.ServletContext; +import java.net.URI; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class StandardOAuth2AuthorizationRequestResolverTest { + private static final String REDIRECT_URI = "https://localhost:8443/nifi-api/callback"; + + private static final String FORWARDED_PATH = "/forwarded"; + + private static final String FORWARDED_REDIRECT_URI = String.format("https://localhost.localdomain%s/nifi-api/callback", FORWARDED_PATH); + + private static final String ALLOWED_CONTEXT_PATHS_PARAMETER = "allowedContextPaths"; + + private static final String AUTHORIZATION_URI = "http://localhost/authorize"; + + private static final String TOKEN_URI = "http://localhost/token"; + + private static final String CLIENT_ID = "client-id"; + + private static final String REGISTRATION_ID = OidcRegistrationProperty.REGISTRATION_ID.getProperty(); + + MockHttpServletRequest httpServletRequest; + + MockHttpServletResponse httpServletResponse; + + @Mock + ClientRegistrationRepository clientRegistrationRepository; + + StandardOAuth2AuthorizationRequestResolver resolver; + + @BeforeEach + void setResolver() { + resolver = new StandardOAuth2AuthorizationRequestResolver(clientRegistrationRepository); + httpServletRequest = new MockHttpServletRequest(); + httpServletResponse = new MockHttpServletResponse(); + } + + @Test + void testResolveNotFound() { + final OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(httpServletRequest); + + assertNull(authorizationRequest); + } + + @Test + void testResolveClientRegistrationIdNotFound() { + final OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(httpServletRequest, null); + + assertNull(authorizationRequest); + } + + @Test + void testResolveFound() { + final URI redirectUri = URI.create(REDIRECT_URI); + httpServletRequest.setScheme(redirectUri.getScheme()); + httpServletRequest.setServerPort(redirectUri.getPort()); + + final ClientRegistration clientRegistration = getClientRegistration(); + when(clientRegistrationRepository.findByRegistrationId(eq(REGISTRATION_ID))).thenReturn(clientRegistration); + + final OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(httpServletRequest, REGISTRATION_ID); + + assertNotNull(authorizationRequest); + assertEquals(REDIRECT_URI, authorizationRequest.getRedirectUri()); + } + + @Test + void testResolveFoundRedirectUriProxyHeaders() { + final ClientRegistration clientRegistration = getClientRegistration(); + when(clientRegistrationRepository.findByRegistrationId(eq(REGISTRATION_ID))).thenReturn(clientRegistration); + + final ServletContext servletContext = httpServletRequest.getServletContext(); + servletContext.setInitParameter(ALLOWED_CONTEXT_PATHS_PARAMETER, FORWARDED_PATH); + + final URI forwardedRedirectUri = URI.create(FORWARDED_REDIRECT_URI); + httpServletRequest.addHeader(WebUtils.PROXY_SCHEME_HTTP_HEADER, forwardedRedirectUri.getScheme()); + httpServletRequest.addHeader(WebUtils.PROXY_HOST_HTTP_HEADER, forwardedRedirectUri.getHost()); + httpServletRequest.addHeader(WebUtils.PROXY_PORT_HTTP_HEADER, forwardedRedirectUri.getPort()); + httpServletRequest.addHeader(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER, FORWARDED_PATH); + + final OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(httpServletRequest, REGISTRATION_ID); + + assertNotNull(authorizationRequest); + assertEquals(FORWARDED_REDIRECT_URI, authorizationRequest.getRedirectUri()); + } + + ClientRegistration getClientRegistration() { + return ClientRegistration.withRegistrationId(OidcRegistrationProperty.REGISTRATION_ID.getProperty()) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientId(CLIENT_ID) + .redirectUri(REDIRECT_URI) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(TOKEN_URI) + .build(); + } +}