mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-05-30 08:42:13 +00:00
HttpSessionOAuth2AuthorizationRequestRepository handle multiple OAuth2AuthorizationRequest per session
Fixes gh-5110
This commit is contained in:
parent
7e6ed52603
commit
59cef7d339
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright 2002-2017 the original author or authors.
|
* Copyright 2002-2018 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -16,11 +16,14 @@
|
|||||||
package org.springframework.security.oauth2.client.web;
|
package org.springframework.security.oauth2.client.web;
|
||||||
|
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import javax.servlet.http.HttpSession;
|
import javax.servlet.http.HttpSession;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An implementation of an {@link AuthorizationRequestRepository} that stores
|
* An implementation of an {@link AuthorizationRequestRepository} that stores
|
||||||
@ -39,9 +42,10 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
|
|||||||
@Override
|
@Override
|
||||||
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
|
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
|
||||||
Assert.notNull(request, "request cannot be null");
|
Assert.notNull(request, "request cannot be null");
|
||||||
HttpSession session = request.getSession(false);
|
Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty");
|
||||||
if (session != null) {
|
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
||||||
return (OAuth2AuthorizationRequest) session.getAttribute(this.sessionAttributeName);
|
if (authorizationRequests != null) {
|
||||||
|
return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE));
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -55,7 +59,9 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
|
|||||||
this.removeAuthorizationRequest(request);
|
this.removeAuthorizationRequest(request);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequest);
|
Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty");
|
||||||
|
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request, true);
|
||||||
|
authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -63,8 +69,26 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
|
|||||||
Assert.notNull(request, "request cannot be null");
|
Assert.notNull(request, "request cannot be null");
|
||||||
OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
|
OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
|
||||||
if (authorizationRequest != null) {
|
if (authorizationRequest != null) {
|
||||||
request.getSession().removeAttribute(this.sessionAttributeName);
|
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
||||||
|
authorizationRequests.remove(authorizationRequest.getState());
|
||||||
}
|
}
|
||||||
return authorizationRequest;
|
return authorizationRequest;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
|
||||||
|
return this.getAuthorizationRequests(request, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request, boolean createSession) {
|
||||||
|
Map<String, OAuth2AuthorizationRequest> authorizationRequests = null;
|
||||||
|
HttpSession session = request.getSession(createSession);
|
||||||
|
if (session != null) {
|
||||||
|
authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
|
||||||
|
if (authorizationRequests == null) {
|
||||||
|
authorizationRequests = new HashMap<>();
|
||||||
|
session.setAttribute(this.sessionAttributeName, authorizationRequests);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return authorizationRequests;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -22,9 +22,11 @@ import org.powermock.modules.junit4.PowerMockRunner;
|
|||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletResponse;
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
|
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
|
||||||
@ -44,8 +46,10 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadAuthorizationRequestWhenNotSavedThenReturnNull() {
|
public void loadAuthorizationRequestWhenNotSavedThenReturnNull() {
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
|
||||||
OAuth2AuthorizationRequest authorizationRequest =
|
OAuth2AuthorizationRequest authorizationRequest =
|
||||||
this.authorizationRequestRepository.loadAuthorizationRequest(new MockHttpServletRequest());
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
|
||||||
assertThat(authorizationRequest).isNull();
|
assertThat(authorizationRequest).isNull();
|
||||||
}
|
}
|
||||||
@ -54,15 +58,69 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||||||
public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() {
|
public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
|
||||||
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest.getState()).thenReturn("state-1234");
|
||||||
|
|
||||||
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
|
||||||
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
||||||
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
|
||||||
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
|
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gh-5110
|
||||||
|
@Test
|
||||||
|
public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() {
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
|
||||||
|
String state1 = "state-1122";
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest1 = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest1.getState()).thenReturn(state1);
|
||||||
|
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
|
||||||
|
|
||||||
|
String state2 = "state-3344";
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest2 = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest2.getState()).thenReturn(state2);
|
||||||
|
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
|
||||||
|
|
||||||
|
String state3 = "state-5566";
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest3 = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest3.getState()).thenReturn(state3);
|
||||||
|
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
|
||||||
|
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, state1);
|
||||||
|
OAuth2AuthorizationRequest loadedAuthorizationRequest1 =
|
||||||
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
|
||||||
|
|
||||||
|
request.removeParameter(OAuth2ParameterNames.STATE);
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, state2);
|
||||||
|
OAuth2AuthorizationRequest loadedAuthorizationRequest2 =
|
||||||
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
|
||||||
|
|
||||||
|
request.removeParameter(OAuth2ParameterNames.STATE);
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, state3);
|
||||||
|
OAuth2AuthorizationRequest loadedAuthorizationRequest3 =
|
||||||
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenThrowIllegalArgumentException() {
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest.getState()).thenReturn("state-1234");
|
||||||
|
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||||
|
authorizationRequest, request, new MockHttpServletResponse());
|
||||||
|
|
||||||
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void saveAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
|
public void saveAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
|
||||||
this.authorizationRequestRepository.saveAuthorizationRequest(
|
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||||
@ -75,13 +133,22 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||||||
mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), null);
|
mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void saveAuthorizationRequestWhenStateNullThenThrowIllegalArgumentException() {
|
||||||
|
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||||
|
mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), new MockHttpServletResponse());
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void saveAuthorizationRequestWhenNotNullThenSaved() {
|
public void saveAuthorizationRequestWhenNotNullThenSaved() {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
|
||||||
|
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest.getState()).thenReturn("state-1234");
|
||||||
this.authorizationRequestRepository.saveAuthorizationRequest(
|
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||||
authorizationRequest, request, new MockHttpServletResponse());
|
authorizationRequest, request, new MockHttpServletResponse());
|
||||||
|
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
|
||||||
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
||||||
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
|
||||||
@ -92,12 +159,17 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||||||
public void saveAuthorizationRequestWhenNullThenRemoved() {
|
public void saveAuthorizationRequestWhenNullThenRemoved() {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
|
||||||
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest.getState()).thenReturn("state-1234");
|
||||||
|
|
||||||
this.authorizationRequestRepository.saveAuthorizationRequest( // Save
|
this.authorizationRequestRepository.saveAuthorizationRequest( // Save
|
||||||
authorizationRequest, request, response);
|
authorizationRequest, request, response);
|
||||||
|
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
|
||||||
this.authorizationRequestRepository.saveAuthorizationRequest( // Null value removes
|
this.authorizationRequestRepository.saveAuthorizationRequest( // Null value removes
|
||||||
null, request, response);
|
null, request, response);
|
||||||
|
|
||||||
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
||||||
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||||
|
|
||||||
@ -113,10 +185,14 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||||||
public void removeAuthorizationRequestWhenSavedThenRemoved() {
|
public void removeAuthorizationRequestWhenSavedThenRemoved() {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
|
||||||
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest.getState()).thenReturn("state-1234");
|
||||||
|
|
||||||
this.authorizationRequestRepository.saveAuthorizationRequest(
|
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||||
authorizationRequest, request, response);
|
authorizationRequest, request, response);
|
||||||
|
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
|
||||||
OAuth2AuthorizationRequest removedAuthorizationRequest =
|
OAuth2AuthorizationRequest removedAuthorizationRequest =
|
||||||
this.authorizationRequestRepository.removeAuthorizationRequest(request);
|
this.authorizationRequestRepository.removeAuthorizationRequest(request);
|
||||||
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
||||||
@ -129,6 +205,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||||||
@Test
|
@Test
|
||||||
public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
|
public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
|
||||||
|
|
||||||
OAuth2AuthorizationRequest removedAuthorizationRequest =
|
OAuth2AuthorizationRequest removedAuthorizationRequest =
|
||||||
this.authorizationRequestRepository.removeAuthorizationRequest(request);
|
this.authorizationRequestRepository.removeAuthorizationRequest(request);
|
||||||
|
@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web;
|
|||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
import org.springframework.http.HttpStatus;
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletResponse;
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
@ -26,8 +27,6 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
|
|||||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
|
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
||||||
|
|
||||||
import javax.servlet.FilterChain;
|
import javax.servlet.FilterChain;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
@ -153,7 +152,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSavedInSession() throws Exception {
|
public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSaved() throws Exception {
|
||||||
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
|
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
|
||||||
"/" + this.registration2.getRegistrationId();
|
"/" + this.registration2.getRegistrationId();
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
@ -162,31 +161,14 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|||||||
FilterChain filterChain = mock(FilterChain.class);
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
|
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
|
||||||
new HttpSessionOAuth2AuthorizationRequestRepository();
|
mock(AuthorizationRequestRepository.class);
|
||||||
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
|
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
|
verify(authorizationRequestRepository).saveAuthorizationRequest(
|
||||||
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
|
any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
assertThat(authorizationRequest).isNotNull();
|
|
||||||
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(
|
|
||||||
this.registration2.getProviderDetails().getAuthorizationUri());
|
|
||||||
assertThat(authorizationRequest.getGrantType()).isEqualTo(
|
|
||||||
this.registration2.getAuthorizationGrantType());
|
|
||||||
assertThat(authorizationRequest.getResponseType()).isEqualTo(
|
|
||||||
OAuth2AuthorizationResponseType.CODE);
|
|
||||||
assertThat(authorizationRequest.getClientId()).isEqualTo(
|
|
||||||
this.registration2.getClientId());
|
|
||||||
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
|
|
||||||
"http://localhost/login/oauth2/code/registration-2");
|
|
||||||
assertThat(authorizationRequest.getScopes()).isEqualTo(
|
|
||||||
this.registration2.getScopes());
|
|
||||||
assertThat(authorizationRequest.getState()).isNotNull();
|
|
||||||
assertThat(authorizationRequest.getAdditionalParameters()
|
|
||||||
.get(OAuth2ParameterNames.REGISTRATION_ID)).isEqualTo(this.registration2.getRegistrationId());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -206,7 +188,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSavedInSession() throws Exception {
|
public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSaved() throws Exception {
|
||||||
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
|
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
|
||||||
"/" + this.registration3.getRegistrationId();
|
"/" + this.registration3.getRegistrationId();
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
@ -215,16 +197,14 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|||||||
FilterChain filterChain = mock(FilterChain.class);
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
|
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
|
||||||
new HttpSessionOAuth2AuthorizationRequestRepository();
|
mock(AuthorizationRequestRepository.class);
|
||||||
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
|
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
verifyZeroInteractions(filterChain);
|
verifyZeroInteractions(filterChain);
|
||||||
|
verify(authorizationRequestRepository, times(0)).saveAuthorizationRequest(
|
||||||
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
|
any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
assertThat(authorizationRequest).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -255,14 +235,19 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|||||||
FilterChain filterChain = mock(FilterChain.class);
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
|
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
|
||||||
new HttpSessionOAuth2AuthorizationRequestRepository();
|
mock(AuthorizationRequestRepository.class);
|
||||||
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
|
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
verifyZeroInteractions(filterChain);
|
ArgumentCaptor<OAuth2AuthorizationRequest> authorizationRequestArgCaptor =
|
||||||
|
ArgumentCaptor.forClass(OAuth2AuthorizationRequest.class);
|
||||||
|
|
||||||
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
|
verifyZeroInteractions(filterChain);
|
||||||
|
verify(authorizationRequestRepository).saveAuthorizationRequest(
|
||||||
|
authorizationRequestArgCaptor.capture(), any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
|
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestArgCaptor.getValue();
|
||||||
|
|
||||||
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
|
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
|
||||||
this.registration2.getRedirectUriTemplate());
|
this.registration2.getRedirectUriTemplate());
|
||||||
|
@ -200,15 +200,16 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
|
public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
|
||||||
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
||||||
|
String state = "state";
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
request.setServletPath(requestUri);
|
request.setServletPath(requestUri);
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
request.addParameter(OAuth2ParameterNames.STATE, state);
|
||||||
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
FilterChain filterChain = mock(FilterChain.class);
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
this.setUpAuthorizationRequest(request, response, this.registration2);
|
this.setUpAuthorizationRequest(request, response, this.registration2, state);
|
||||||
this.setUpAuthenticationResult(this.registration2);
|
this.setUpAuthenticationResult(this.registration2);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
@ -219,15 +220,16 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||||||
@Test
|
@Test
|
||||||
public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception {
|
public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception {
|
||||||
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
||||||
|
String state = "state";
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
request.setServletPath(requestUri);
|
request.setServletPath(requestUri);
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
request.addParameter(OAuth2ParameterNames.STATE, state);
|
||||||
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
FilterChain filterChain = mock(FilterChain.class);
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
this.setUpAuthorizationRequest(request, response, this.registration1, state);
|
||||||
this.setUpAuthenticationResult(this.registration1);
|
this.setUpAuthenticationResult(this.registration1);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
@ -248,15 +250,16 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||||||
this.filter.setAuthenticationManager(this.authenticationManager);
|
this.filter.setAuthenticationManager(this.authenticationManager);
|
||||||
|
|
||||||
String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
|
String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
|
||||||
|
String state = "state";
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
request.setServletPath(requestUri);
|
request.setServletPath(requestUri);
|
||||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
request.addParameter(OAuth2ParameterNames.STATE, state);
|
||||||
|
|
||||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
FilterChain filterChain = mock(FilterChain.class);
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
this.setUpAuthorizationRequest(request, response, this.registration2);
|
this.setUpAuthorizationRequest(request, response, this.registration2, state);
|
||||||
this.setUpAuthenticationResult(this.registration2);
|
this.setUpAuthenticationResult(this.registration2);
|
||||||
|
|
||||||
this.filter.doFilter(request, response, filterChain);
|
this.filter.doFilter(request, response, filterChain);
|
||||||
@ -285,8 +288,9 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
||||||
ClientRegistration registration) {
|
ClientRegistration registration, String state) {
|
||||||
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
||||||
|
when(authorizationRequest.getState()).thenReturn(state);
|
||||||
Map<String, Object> additionalParameters = new HashMap<>();
|
Map<String, Object> additionalParameters = new HashMap<>();
|
||||||
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
||||||
when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
|
when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
|
||||||
|
@ -250,7 +250,7 @@ public class OAuth2LoginApplicationTests {
|
|||||||
|
|
||||||
HtmlElement errorElement = page.getBody().getFirstByXPath("p");
|
HtmlElement errorElement = page.getBody().getFirstByXPath("p");
|
||||||
assertThat(errorElement).isNotNull();
|
assertThat(errorElement).isNotNull();
|
||||||
assertThat(errorElement.asText()).contains("invalid_state_parameter");
|
assertThat(errorElement.asText()).contains("authorization_request_not_found");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
x
Reference in New Issue
Block a user