diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt index 8f09f5589a..effedad6f7 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt @@ -20,6 +20,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher import org.springframework.web.server.ServerWebExchange +import org.springframework.web.server.WebFilter /** * Configures [ServerHttpSecurity] using a [ServerHttpSecurity Kotlin DSL][ServerHttpSecurityDsl]. @@ -89,6 +90,81 @@ class ServerHttpSecurityDsl(private val http: ServerHttpSecurity, private val in this.http.securityMatcher(securityMatcher) } + /** + * Adds a [WebFilter] at a specific position. + * + * Example: + * + * ``` + * @EnableWebFluxSecurity + * class SecurityConfig { + * + * @Bean + * fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + * return http { + * addFilterAt(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE) + * } + * } + * } + * ``` + * + * @param webFilter the [WebFilter] to add + * @param order the place to insert the [WebFilter] + */ + fun addFilterAt(webFilter: WebFilter, order: SecurityWebFiltersOrder) { + this.http.addFilterAt(webFilter, order) + } + + /** + * Adds a [WebFilter] before specific position. + * + * Example: + * + * ``` + * @EnableWebFluxSecurity + * class SecurityConfig { + * + * @Bean + * fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + * return http { + * addFilterBefore(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE) + * } + * } + * } + * ``` + * + * @param webFilter the [WebFilter] to add + * @param order the place before which to insert the [WebFilter] + */ + fun addFilterBefore(webFilter: WebFilter, order: SecurityWebFiltersOrder) { + this.http.addFilterBefore(webFilter, order) + } + + /** + * Adds a [WebFilter] after specific position. + * + * Example: + * + * ``` + * @EnableWebFluxSecurity + * class SecurityConfig { + * + * @Bean + * fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + * return http { + * addFilterAfter(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE) + * } + * } + * } + * ``` + * + * @param webFilter the [WebFilter] to add + * @param order the place after which to insert the [WebFilter] + */ + fun addFilterAfter(webFilter: WebFilter, order: SecurityWebFiltersOrder) { + this.http.addFilterAfter(webFilter, order) + } + /** * Enables form based authentication. * diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt index 0b6135741f..2128c9c9d4 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt @@ -16,6 +16,7 @@ package org.springframework.security.config.web.server +import org.assertj.core.api.Assertions.assertThat import org.junit.Rule import org.junit.Test import org.springframework.beans.factory.annotation.Autowired @@ -26,6 +27,7 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux import org.springframework.security.config.test.SpringTestRule import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter import org.springframework.security.web.server.header.StrictTransportSecurityServerHttpHeadersWriter import org.springframework.security.web.server.header.XFrameOptionsServerHttpHeadersWriter @@ -33,6 +35,10 @@ import org.springframework.security.web.server.header.XXssProtectionServerHttpHe import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.web.reactive.config.EnableWebFlux +import org.springframework.web.server.ServerWebExchange +import org.springframework.web.server.WebFilter +import org.springframework.web.server.WebFilterChain +import reactor.core.publisher.Mono /** * Tests for [ServerHttpSecurityDsl] @@ -123,4 +129,74 @@ class ServerHttpSecurityDslTests { } } } + + @Test + fun `add filter at applies custom at specified filter position`() { + this.spring.register(CustomWebFilterAtConfig::class.java).autowire() + val filterChain = this.spring.context.getBean(SecurityWebFilterChain::class.java) + val filters = filterChain.webFilters.collectList().block() + + assertThat(filters).last().isExactlyInstanceOf(CustomWebFilter::class.java) + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomWebFilterAtConfig { + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + addFilterAt(CustomWebFilter(), SecurityWebFiltersOrder.LAST) + } + } + } + + @Test + fun `add filter before applies custom before specified filter position`() { + this.spring.register(CustomWebFilterBeforeConfig::class.java).autowire() + val filterChain = this.spring.context.getBean(SecurityWebFilterChain::class.java) + val filters: List>? = filterChain.webFilters.map { it.javaClass }.collectList().block() + + assertThat(filters).containsSubsequence( + CustomWebFilter::class.java, + SecurityContextServerWebExchangeWebFilter::class.java + ) + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomWebFilterBeforeConfig { + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + addFilterBefore(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE) + } + } + } + + @Test + fun `add filter after applies custom after specified filter position`() { + this.spring.register(CustomWebFilterAfterConfig::class.java).autowire() + val filterChain = this.spring.context.getBean(SecurityWebFilterChain::class.java) + val filters: List>? = filterChain.webFilters.map { it.javaClass }.collectList().block() + + assertThat(filters).containsSubsequence( + SecurityContextServerWebExchangeWebFilter::class.java, + CustomWebFilter::class.java + ) + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomWebFilterAfterConfig { + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + addFilterAfter(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE) + } + } + } + + class CustomWebFilter : WebFilter { + override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono = Mono.empty() + } }