diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java index b3a512ee55..a8aa973e2a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client; import org.springframework.lang.Nullable; +import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.util.Assert; @@ -134,6 +135,33 @@ public final class OAuth2AuthorizeRequest { this.authorizedClient = authorizedClient; } + /** + * Sets the name of the {@code Principal} (to be) associated to the authorized client. + * + * @since 5.3 + * @param principalName the name of the {@code Principal} (to be) associated to the authorized client + * @return the {@link Builder} + */ + public Builder principal(String principalName) { + return principal(createAuthentication(principalName)); + } + + private static Authentication createAuthentication(final String principalName) { + Assert.hasText(principalName, "principalName cannot be empty"); + + return new AbstractAuthenticationToken(null) { + @Override + public Object getCredentials() { + return ""; + } + + @Override + public Object getPrincipal() { + return principalName; + } + }; + } + /** * Sets the {@code Principal} (to be) associated to the authorized client. * diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java index d04c4b5beb..37c7e89eef 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.entry; /** * Tests for {@link OAuth2AuthorizeRequest}. @@ -58,6 +60,13 @@ public class OAuth2AuthorizeRequestTests { .hasMessage("principal cannot be null"); } + @Test + public void withClientRegistrationIdWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal((String) null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principalName cannot be empty"); + } + @Test public void withClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) @@ -89,4 +98,15 @@ public class OAuth2AuthorizeRequestTests { assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); assertThat(authorizeRequest.getAttributes()).contains(entry("name1", "value1"), entry("name2", "value2")); } + + @Test + public void withClientRegistrationIdWhenPrincipalNameProvidedThenPrincipalCreated() { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal("principalName") + .build(); + + assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); + assertThat(authorizeRequest.getAuthorizedClient()).isNull(); + assertThat(authorizeRequest.getPrincipal().getName()).isEqualTo("principalName"); + } }