diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/ServerOAuth2LoginAuthenticationTokenConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/ServerOAuth2LoginAuthenticationTokenConverter.java new file mode 100644 index 0000000000..80d62789ec --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/ServerOAuth2LoginAuthenticationTokenConverter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import java.util.function.Function; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.UriComponentsBuilder; + +import reactor.core.publisher.Mono; + + +/** + * Converts from a {@link ServerWebExchange} to an {@link OAuth2LoginAuthenticationToken} that can be authenticated. The + * converter does not validate any errors it only performs a conversion. + * @author Rob Winch + * @since 5.1 + * @see org.springframework.security.web.server.authentication.AuthenticationWebFilter#setAuthenticationConverter(Function) + */ +public class ServerOAuth2LoginAuthenticationTokenConverter implements + Function> { + + static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; + + static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found"; + + private ReactiveAuthorizationRequestRepository authorizationRequestRepository = + new WebSessionOAuth2ReactiveAuthorizationRequestRepository(); + + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + + public ServerOAuth2LoginAuthenticationTokenConverter( + ReactiveClientRegistrationRepository clientRegistrationRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + } + + /** + * Sets the {@link ReactiveAuthorizationRequestRepository} to be used. The default is + * {@link WebSessionOAuth2ReactiveAuthorizationRequestRepository}. + * @param authorizationRequestRepository the repository to use. + */ + public void setAuthorizationRequestRepository( + ReactiveAuthorizationRequestRepository authorizationRequestRepository) { + Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + this.authorizationRequestRepository = authorizationRequestRepository; + } + + @Override + public Mono apply(ServerWebExchange serverWebExchange) { + return this.authorizationRequestRepository.removeAuthorizationRequest(serverWebExchange) + .switchIfEmpty(oauth2AuthenticationException(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE)) + .flatMap(authorizationRequest -> authenticationRequest(serverWebExchange, authorizationRequest)); + } + + private Mono oauth2AuthenticationException(String errorCode) { + return Mono.defer(() -> { + OAuth2Error oauth2Error = new OAuth2Error(errorCode); + return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString())); + }); + } + + private Mono authenticationRequest(ServerWebExchange exchange, OAuth2AuthorizationRequest authorizationRequest) { + return Mono.just(authorizationRequest) + .map(OAuth2AuthorizationRequest::getAdditionalParameters) + .flatMap(additionalParams -> { + String id = (String) additionalParams.get(OAuth2ParameterNames.REGISTRATION_ID); + if (id == null) { + return oauth2AuthenticationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); + } + return this.clientRegistrationRepository.findByRegistrationId(id); + }) + .switchIfEmpty(oauth2AuthenticationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE)) + .map(clientRegistration -> { + OAuth2AuthorizationResponse authorizationResponse = convert(exchange); + OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken( + clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); + return authenticationRequest; + }); + } + + private static OAuth2AuthorizationResponse convert(ServerWebExchange exchange) { + MultiValueMap queryParams = exchange.getRequest() + .getQueryParams(); + String redirectUri = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()) + .query(null) + .build() + .toUriString(); + + return OAuth2AuthorizationResponseUtils.convert(queryParams, redirectUri); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/ServerOAuth2LoginAuthenticationTokenConverterTest.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/ServerOAuth2LoginAuthenticationTokenConverterTest.java new file mode 100644 index 0000000000..8ceb5b1727 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/ServerOAuth2LoginAuthenticationTokenConverterTest.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; + +import reactor.core.publisher.Mono; + +/** + * @author Rob Winch + * @since 5.1 + */ +@RunWith(MockitoJUnitRunner.class) +public class ServerOAuth2LoginAuthenticationTokenConverterTest { + @Mock + private ReactiveClientRegistrationRepository clientRegistrationRepository; + + @Mock + private ReactiveAuthorizationRequestRepository authorizationRequestRepository; + + private String clientRegistrationId = "github"; + + private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) + .redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .scope("read:user") + .authorizationUri("https://github.com/login/oauth/authorize") + .tokenUri("https://github.com/login/oauth/access_token") + .userInfoUri("https://api.github.com/user") + .userNameAttributeName("id") + .clientName("GitHub") + .clientId("clientId") + .clientSecret("clientSecret") + .build(); + + private OAuth2AuthorizationRequest.Builder authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state("state") + .additionalParameters(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId)); + + private final MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/"); + + private ServerOAuth2LoginAuthenticationTokenConverter converter; + + @Before + public void setup() { + this.converter = new ServerOAuth2LoginAuthenticationTokenConverter(this.clientRegistrationRepository); + this.converter.setAuthorizationRequestRepository(this.authorizationRequestRepository); + } + + @Test + public void applyWhenAuthorizationRequestEmptyThenOAuth2AuthenticationException() { + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.empty()); + + assertThatThrownBy(() -> applyConverter()) + .isInstanceOf(OAuth2AuthenticationException.class); + } + + @Test + public void applyWhenAdditionalParametersMissingThenOAuth2AuthenticationException() { + this.authorizationRequest.additionalParameters(Collections.emptyMap()); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); + + assertThatThrownBy(() -> applyConverter()) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining(ServerOAuth2LoginAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); + } + + @Test + public void applyWhenClientRegistrationMissingThenOAuth2AuthenticationException() { + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty()); + + assertThatThrownBy(() -> applyConverter()) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining(ServerOAuth2LoginAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); + } + + @Test + public void applyWhenCodeParameterNotFoundThenErrorCode() { + this.request.queryParam(OAuth2ParameterNames.ERROR, "error"); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration)); + + assertThat(applyConverter().getAuthorizationExchange().getAuthorizationResponse().getError().getErrorCode()) + .isEqualTo("error"); + } + + @Test + public void applyWhenCodeParameterFoundThenCode() { + this.request.queryParam(OAuth2ParameterNames.CODE, "code"); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration)); + + OAuth2LoginAuthenticationToken result = applyConverter(); + + OAuth2AuthorizationResponse exchange = result + .getAuthorizationExchange().getAuthorizationResponse(); + assertThat(exchange.getError()).isNull(); + assertThat(exchange.getCode()).isEqualTo("code"); + } + + private OAuth2LoginAuthenticationToken applyConverter() { + MockServerWebExchange exchange = MockServerWebExchange.from(this.request); + return (OAuth2LoginAuthenticationToken) this.converter.apply(exchange).block(); + } +}