Polish Max Sessions on WebFlux

This commit changes the PreventLoginServerMaximumSessionsExceededHandler to invalidate the WebSession in addition to throwing the error, this is needed otherwise the session would still be saved with the security context. It also changes the SessionRegistryWebSession to first perform the operation on the delegate and then invoke the needed method on the ReactiveSessionRegistry

Issue gh-6192
This commit is contained in:
Marcus Hert Da Coregio 2024-02-27 11:12:41 -03:00
parent c639d0a514
commit a5ce8ae87f
7 changed files with 56 additions and 28 deletions

View File

@ -2143,19 +2143,23 @@ public class ServerHttpSecurity {
@Override
public Mono<Void> changeSessionId() {
String currentId = this.session.getId();
return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> this.session.changeSessionId().thenReturn(information))
.flatMap((information) -> {
information = information.withSessionId(this.session.getId());
return SessionRegistryWebFilter.this.sessionRegistry.saveSessionInformation(information);
});
return this.session.changeSessionId()
.then(Mono.defer(
() -> SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> {
information = information.withSessionId(this.session.getId());
return SessionRegistryWebFilter.this.sessionRegistry
.saveSessionInformation(information);
})));
}
@Override
public Mono<Void> invalidate() {
String currentId = this.session.getId();
return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> this.session.invalidate());
return this.session.invalidate()
.then(Mono.defer(() -> SessionRegistryWebFilter.this.sessionRegistry
.removeSessionInformation(currentId)))
.then();
}
@Override

View File

@ -67,6 +67,7 @@ import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import org.springframework.web.server.session.DefaultWebSessionManager;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@ -95,14 +96,19 @@ public class SessionManagementSpecTests {
ResponseCookie firstLoginSessionCookie = loginReturningCookie(data);
// second login should fail
this.client.mutateWith(csrf())
ResponseCookie secondLoginSessionCookie = this.client.mutateWith(csrf())
.post()
.uri("/login")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromFormData(data))
.exchange()
.expectHeader()
.location("/login?error");
.location("/login?error")
.returnResult(Void.class)
.getResponseCookies()
.getFirst("SESSION");
assertThat(secondLoginSessionCookie).isNull();
// first login should still be valid
this.client.mutateWith(csrf())

View File

@ -81,8 +81,8 @@ public final class ConcurrentSessionControlServerAuthenticationSuccessHandler
}
}
}
return this.maximumSessionsExceededHandler
.handle(new MaximumSessionsContext(authentication, registeredSessions, maximumSessions));
return this.maximumSessionsExceededHandler.handle(new MaximumSessionsContext(authentication,
registeredSessions, maximumSessions, currentSession));
});
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import java.util.List;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.session.ReactiveSessionInformation;
import org.springframework.web.server.WebSession;
public final class MaximumSessionsContext {
@ -29,11 +30,14 @@ public final class MaximumSessionsContext {
private final int maximumSessionsAllowed;
private final WebSession currentSession;
public MaximumSessionsContext(Authentication authentication, List<ReactiveSessionInformation> sessions,
int maximumSessionsAllowed) {
int maximumSessionsAllowed, WebSession currentSession) {
this.authentication = authentication;
this.sessions = sessions;
this.maximumSessionsAllowed = maximumSessionsAllowed;
this.currentSession = currentSession;
}
public Authentication getAuthentication() {
@ -48,4 +52,8 @@ public final class MaximumSessionsContext {
return this.maximumSessionsAllowed;
}
public WebSession getCurrentSession() {
return this.currentSession;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,9 +31,9 @@ public final class PreventLoginServerMaximumSessionsExceededHandler implements S
@Override
public Mono<Void> handle(MaximumSessionsContext context) {
return Mono
.error(new SessionAuthenticationException("Maximum sessions of " + context.getMaximumSessionsAllowed()
+ " for authentication '" + context.getAuthentication().getName() + "' exceeded"));
return context.getCurrentSession()
.invalidate()
.then(Mono.defer(() -> Mono.error(new SessionAuthenticationException("Maximum sessions exceeded"))));
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -50,7 +50,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests {
given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L));
given(session2.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
List.of(session1, session2), 2);
List.of(session1, session2), 2, null);
this.handler.handle(context).block();
@ -72,7 +72,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests {
given(session1.invalidate()).willReturn(Mono.empty());
given(session2.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
List.of(session1, session2, session3), 2);
List.of(session1, session2, session3), 2, null);
this.handler.handle(context).block();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,13 +19,19 @@ package org.springframework.security.web.server.authentication.session;
import java.util.Collections;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.web.authentication.session.SessionAuthenticationException;
import org.springframework.security.web.server.authentication.MaximumSessionsContext;
import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler;
import org.springframework.web.server.WebSession;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link PreventLoginServerMaximumSessionsExceededHandler}.
@ -35,13 +41,17 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
class PreventLoginServerMaximumSessionsExceededHandlerTests {
@Test
void handleWhenInvokedThenThrowsSessionAuthenticationException() {
void handleWhenInvokedThenInvalidateWebSessionAndThrowsSessionAuthenticationException() {
PreventLoginServerMaximumSessionsExceededHandler handler = new PreventLoginServerMaximumSessionsExceededHandler();
WebSession webSession = mock();
given(webSession.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(TestAuthentication.authenticatedUser(),
Collections.emptyList(), 1);
assertThatExceptionOfType(SessionAuthenticationException.class)
.isThrownBy(() -> handler.handle(context).block())
.withMessage("Maximum sessions of 1 for authentication 'user' exceeded");
Collections.emptyList(), 1, webSession);
StepVerifier.create(handler.handle(context)).expectErrorSatisfies((ex) -> {
assertThat(ex).isInstanceOf(SessionAuthenticationException.class);
assertThat(ex.getMessage()).isEqualTo("Maximum sessions exceeded");
}).verify();
verify(webSession).invalidate();
}
}