diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java index db478be3ab..131f88bf58 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.io.Serializable; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import org.springframework.util.Assert; @@ -78,4 +79,20 @@ public class DefaultSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPri this.registrationId = registrationId; } + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + if (!(object instanceof DefaultSaml2AuthenticatedPrincipal that)) { + return false; + } + return Objects.equals(this.name, that.name) && Objects.equals(this.registrationId, that.registrationId); + } + + @Override + public int hashCode() { + return Objects.hash(this.name, this.registrationId); + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java index 1b9b33fb7c..4ea06059a6 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,4 +81,24 @@ public class DefaultSaml2AuthenticatedPrincipalTests { assertThat((Instant) registrationInfo.get(1)).isEqualTo(registeredDate); } + // gh-15346 + @Test + public void whenUsedAsKeyInMapThenRetrievableAcrossSerialization() { + Map valuesByPrincipal = new LinkedHashMap<>(); + DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user", Map.of()); + valuesByPrincipal.put(principal, 1); + principal = new DefaultSaml2AuthenticatedPrincipal("user", Map.of()); + assertThat(valuesByPrincipal.get(principal)).isEqualTo(1); + principal = new DefaultSaml2AuthenticatedPrincipal("user", Map.of()); + principal.setRelyingPartyRegistrationId("id"); + assertThat(valuesByPrincipal.get(principal)).isNull(); + valuesByPrincipal.put(principal, 2); + principal = new DefaultSaml2AuthenticatedPrincipal("user", Map.of()); + principal.setRelyingPartyRegistrationId("id"); + assertThat(valuesByPrincipal.get(principal)).isEqualTo(2); + principal = new DefaultSaml2AuthenticatedPrincipal("USER", Map.of()); + principal.setRelyingPartyRegistrationId("id"); + assertThat(valuesByPrincipal.get(principal)).isNull(); + } + }