diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt index 639130f17f..d308429c08 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.config.web.server import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.context.ServerSecurityContextRepository import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher import org.springframework.web.server.ServerWebExchange import org.springframework.web.server.WebFilter @@ -65,6 +66,7 @@ operator fun ServerHttpSecurity.invoke(httpConfiguration: ServerHttpSecurityDsl. class ServerHttpSecurityDsl(private val http: ServerHttpSecurity, private val init: ServerHttpSecurityDsl.() -> Unit) { var authenticationManager: ReactiveAuthenticationManager? = null + var securityContextRepository: ServerSecurityContextRepository? = null /** * Allows configuring the [ServerHttpSecurity] to only be invoked when matching the @@ -718,6 +720,7 @@ class ServerHttpSecurityDsl(private val http: ServerHttpSecurity, private val in internal fun build(): SecurityWebFilterChain { init() authenticationManager?.also { this.http.authenticationManager(authenticationManager) } + securityContextRepository?.also { this.http.securityContextRepository(securityContextRepository) } return this.http.build() } } diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt index 52dd70ad6b..825278d9d8 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt @@ -35,7 +35,9 @@ import org.springframework.security.config.test.SpringTestContextExtension import org.springframework.security.core.Authentication import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter +import org.springframework.security.web.server.context.ServerSecurityContextRepository import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter import org.springframework.security.web.server.header.StrictTransportSecurityServerHttpHeadersWriter import org.springframework.security.web.server.header.XFrameOptionsServerHttpHeadersWriter @@ -251,4 +253,31 @@ class ServerHttpSecurityDslTests { return Mono.empty() } } + + @Test + fun `security context repository when configured in DSL then used`() { + this.spring.register(SecurityContextRepositoryConfig::class.java).autowire() + mockkObject(SecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY) + every { + SecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.load(any()) + } returns Mono.empty() + this.client.get().uri("/").exchange() + verify(exactly = 1) { SecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.load(any()) } + } + + @Configuration + @EnableWebFlux + @EnableWebFluxSecurity + open class SecurityContextRepositoryConfig { + companion object { + val SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = NoOpServerSecurityContextRepository.getInstance() + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + securityContextRepository = SECURITY_CONTEXT_REPOSITORY + } + } + } }