diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java index 222753a2fe..7437119800 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java @@ -18,7 +18,6 @@ package org.springframework.security.oauth2.client.web.server; import org.springframework.http.HttpStatus; import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.http.server.reactive.ServerHttpRequestDecorator; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -160,7 +159,7 @@ public class DefaultServerOAuth2AuthorizationRequestResolver Map uriVariables = new HashMap<>(); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - String baseUrl = UriComponentsBuilder.fromHttpRequest(new ServerHttpRequestDecorator(request)) + String baseUrl = UriComponentsBuilder.fromUri(request.getURI()) .replacePath(request.getPath().contextPath().value()) .replaceQuery(null) .build() diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java index f0289dbb61..86a722d18c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java @@ -90,6 +90,20 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests { return this.resolver.resolve(exchange).block(); } + @Test + public void resolveWhenForwardedHeadersClientRegistrationFoundThenWorks() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( + Mono.just(this.registration)); + ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/oauth2/authorization/id").header("X-Forwarded-Host", "evil.com")); + + OAuth2AuthorizationRequest request = this.resolver.resolve(exchange).block(); + + assertThat(request.getAuthorizationRequestUri()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.*?&" + + "redirect_uri=/login/oauth2/code/registration-id"); + } + @Test public void resolveWhenAuthorizationRequestWithValidPkceClientThenResolves() { when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(