diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 36ee156e92..0385408bc2 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -82,6 +82,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts; +import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; @@ -113,10 +114,10 @@ import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -429,6 +430,8 @@ public class Saml2LoginConfigurerTests { private void performSaml2Login(String expected) throws IOException, ServletException { // setup authentication parameters + this.request.setRequestURI("/login/saml2/sso/registration-id"); + this.request.setServletPath("/login/saml2/sso/registration-id"); this.request.setParameter("SAMLResponse", Base64.getEncoder().encodeToString("saml2-xml-response-object".getBytes())); // perform test @@ -821,9 +824,7 @@ public class Saml2LoginConfigurerTests { .assertingPartyDetails((party) -> party.verificationX509Credentials( (c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))) .build(); - RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); - given(repository.findByRegistrationId(anyString())).willReturn(registration); - return repository; + return spy(new InMemoryRelyingPartyRegistrationRepository(registration)); } }