Migrate Kotlin tests from java Mockito to Mockk

Closes gh-9785
This commit is contained in:
theexiile1305 2021-06-04 15:32:09 +02:00 committed by Eleftheria Stein-Kousathana
parent ca76c54471
commit 3074ad4136
27 changed files with 819 additions and 509 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,11 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito.`when`
import org.mockito.Mockito.mock
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -33,6 +32,7 @@ import org.springframework.security.web.server.authorization.ServerAccessDeniedH
import org.springframework.security.web.server.csrf.CsrfToken import org.springframework.security.web.server.csrf.CsrfToken
import org.springframework.security.web.server.csrf.DefaultCsrfToken import org.springframework.security.web.server.csrf.DefaultCsrfToken
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.bind.annotation.PostMapping import org.springframework.web.bind.annotation.PostMapping
@ -161,20 +161,20 @@ class ServerCsrfDslTests {
@Test @Test
fun `csrf when custom access denied handler then handler used`() { fun `csrf when custom access denied handler then handler used`() {
this.spring.register(CustomAccessDeniedHandlerConfig::class.java).autowire() this.spring.register(CustomAccessDeniedHandlerConfig::class.java).autowire()
mockkObject(CustomAccessDeniedHandlerConfig.ACCESS_DENIED_HANDLER)
this.client.post() this.client.post()
.uri("/") .uri("/")
.exchange() .exchange()
Mockito.verify(CustomAccessDeniedHandlerConfig.ACCESS_DENIED_HANDLER) verify(exactly = 1) { CustomAccessDeniedHandlerConfig.ACCESS_DENIED_HANDLER.handle(any(), any()) }
.handle(any(), any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomAccessDeniedHandlerConfig { open class CustomAccessDeniedHandlerConfig {
companion object { companion object {
var ACCESS_DENIED_HANDLER: ServerAccessDeniedHandler = mock(ServerAccessDeniedHandler::class.java) val ACCESS_DENIED_HANDLER: ServerAccessDeniedHandler = ServerAccessDeniedHandler { _, _ -> Mono.empty() }
} }
@Bean @Bean
@ -189,23 +189,24 @@ class ServerCsrfDslTests {
@Test @Test
fun `csrf when custom token repository then repository used`() { fun `csrf when custom token repository then repository used`() {
`when`(CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY.loadToken(any()))
.thenReturn(Mono.just(this.token))
this.spring.register(CustomCsrfTokenRepositoryConfig::class.java).autowire() this.spring.register(CustomCsrfTokenRepositoryConfig::class.java).autowire()
mockkObject(CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY)
every {
CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY.loadToken(any())
} returns Mono.just(this.token)
this.client.post() this.client.post()
.uri("/") .uri("/")
.exchange() .exchange()
Mockito.verify(CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY) verify(exactly = 1) { CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY.loadToken(any()) }
.loadToken(any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomCsrfTokenRepositoryConfig { open class CustomCsrfTokenRepositoryConfig {
companion object { companion object {
var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java) val TOKEN_REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
} }
@Bean @Bean
@ -220,11 +221,14 @@ class ServerCsrfDslTests {
@Test @Test
fun `csrf when multipart form data and not enabled then denied`() { fun `csrf when multipart form data and not enabled then denied`() {
`when`(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.loadToken(any()))
.thenReturn(Mono.just(this.token))
`when`(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.generateToken(any()))
.thenReturn(Mono.just(this.token))
this.spring.register(MultipartFormDataNotEnabledConfig::class.java).autowire() this.spring.register(MultipartFormDataNotEnabledConfig::class.java).autowire()
mockkObject(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY)
every {
MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.loadToken(any())
} returns Mono.just(this.token)
every {
MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.generateToken(any())
} returns Mono.just(this.token)
this.client.post() this.client.post()
.uri("/") .uri("/")
@ -238,7 +242,7 @@ class ServerCsrfDslTests {
@EnableWebFlux @EnableWebFlux
open class MultipartFormDataNotEnabledConfig { open class MultipartFormDataNotEnabledConfig {
companion object { companion object {
var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java) val TOKEN_REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
} }
@Bean @Bean
@ -253,11 +257,14 @@ class ServerCsrfDslTests {
@Test @Test
fun `csrf when multipart form data and enabled then granted`() { fun `csrf when multipart form data and enabled then granted`() {
`when`(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.loadToken(any()))
.thenReturn(Mono.just(this.token))
`when`(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any()))
.thenReturn(Mono.just(this.token))
this.spring.register(MultipartFormDataEnabledConfig::class.java).autowire() this.spring.register(MultipartFormDataEnabledConfig::class.java).autowire()
mockkObject(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY)
every {
MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.loadToken(any())
} returns Mono.just(this.token)
every {
MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any())
} returns Mono.just(this.token)
this.client.post() this.client.post()
.uri("/") .uri("/")
@ -271,7 +278,7 @@ class ServerCsrfDslTests {
@EnableWebFlux @EnableWebFlux
open class MultipartFormDataEnabledConfig { open class MultipartFormDataEnabledConfig {
companion object { companion object {
var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java) val TOKEN_REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,11 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.mockkObject
import io.mockk.verify
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito.verify
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -29,22 +28,23 @@ import org.springframework.context.annotation.Configuration
import org.springframework.http.HttpMethod import org.springframework.http.HttpMethod
import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf import org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf
import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationFailureHandler import org.springframework.security.web.server.authentication.RedirectServerAuthenticationFailureHandler
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler
import org.springframework.security.web.server.context.ServerSecurityContextRepository import org.springframework.security.web.server.context.ServerSecurityContextRepository
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers
import org.springframework.test.web.reactive.server.FluxExchangeResult import org.springframework.test.web.reactive.server.FluxExchangeResult
import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.util.LinkedMultiValueMap import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.reactive.config.EnableWebFlux
import org.springframework.web.reactive.function.BodyInserters import org.springframework.web.reactive.function.BodyInserters
import reactor.core.publisher.Mono
/** /**
* Tests for [ServerFormLoginDsl] * Tests for [ServerFormLoginDsl]
@ -129,9 +129,11 @@ class ServerFormLoginDslTests {
@Test @Test
fun `form login when custom authentication manager then manager used`() { fun `form login when custom authentication manager then manager used`() {
this.spring.register(CustomAuthenticationManagerConfig::class.java).autowire() this.spring.register(CustomAuthenticationManagerConfig::class.java).autowire()
val data: MultiValueMap<String, String> = LinkedMultiValueMap() mockkObject(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER)
data.add("username", "user") val data = LinkedMultiValueMap<String, String>().apply {
data.add("password", "password") add("username", "user")
add("password", "password")
}
this.client this.client
.mutateWith(csrf()) .mutateWith(csrf())
@ -140,15 +142,15 @@ class ServerFormLoginDslTests {
.body(BodyInserters.fromFormData(data)) .body(BodyInserters.fromFormData(data))
.exchange() .exchange()
verify<ReactiveAuthenticationManager>(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER) verify(exactly = 1) { CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
.authenticate(any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomAuthenticationManagerConfig { open class CustomAuthenticationManagerConfig {
companion object { companion object {
var AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = Mockito.mock(ReactiveAuthenticationManager::class.java) val AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = ReactiveAuthenticationManager { Mono.empty() }
} }
@Bean @Bean
@ -182,9 +184,10 @@ class ServerFormLoginDslTests {
@Test @Test
fun `form login when custom requires authentication matcher then matching request logs in`() { fun `form login when custom requires authentication matcher then matching request logs in`() {
this.spring.register(CustomConfig::class.java, UserDetailsConfig::class.java).autowire() this.spring.register(CustomConfig::class.java, UserDetailsConfig::class.java).autowire()
val data: MultiValueMap<String, String> = LinkedMultiValueMap() val data = LinkedMultiValueMap<String, String>().apply {
data.add("username", "user") add("username", "user")
data.add("password", "password") add("password", "password")
}
val result = this.client val result = this.client
.mutateWith(csrf()) .mutateWith(csrf())
@ -238,9 +241,10 @@ class ServerFormLoginDslTests {
@Test @Test
fun `login when custom success handler then success handler used`() { fun `login when custom success handler then success handler used`() {
this.spring.register(CustomSuccessHandlerConfig::class.java, UserDetailsConfig::class.java).autowire() this.spring.register(CustomSuccessHandlerConfig::class.java, UserDetailsConfig::class.java).autowire()
val data: MultiValueMap<String, String> = LinkedMultiValueMap() val data = LinkedMultiValueMap<String, String>().apply {
data.add("username", "user") add("username", "user")
data.add("password", "password") add("password", "password")
}
val result = this.client val result = this.client
.mutateWith(csrf()) .mutateWith(csrf())
@ -275,9 +279,11 @@ class ServerFormLoginDslTests {
@Test @Test
fun `form login when custom security context repository then repository used`() { fun `form login when custom security context repository then repository used`() {
this.spring.register(CustomSecurityContextRepositoryConfig::class.java, UserDetailsConfig::class.java).autowire() this.spring.register(CustomSecurityContextRepositoryConfig::class.java, UserDetailsConfig::class.java).autowire()
val data: MultiValueMap<String, String> = LinkedMultiValueMap() mockkObject(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY)
data.add("username", "user") val data = LinkedMultiValueMap<String, String>().apply {
data.add("password", "password") add("username", "user")
add("password", "password")
}
this.client this.client
.mutateWith(csrf()) .mutateWith(csrf())
@ -286,15 +292,15 @@ class ServerFormLoginDslTests {
.body(BodyInserters.fromFormData(data)) .body(BodyInserters.fromFormData(data))
.exchange() .exchange()
verify<ServerSecurityContextRepository>(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY) verify(exactly = 1) { CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.save(any(), any()) }
.save(Mockito.any(), Mockito.any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomSecurityContextRepositoryConfig { open class CustomSecurityContextRepositoryConfig {
companion object { companion object {
var SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = Mockito.mock(ServerSecurityContextRepository::class.java) val SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = WebSessionServerSecurityContextRepository()
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,10 +16,12 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import java.util.Base64
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.BDDMockito.given
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -27,19 +29,19 @@ import org.springframework.context.annotation.Configuration
import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.authentication.TestingAuthenticationToken import org.springframework.security.authentication.TestingAuthenticationToken
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.core.Authentication import org.springframework.security.core.Authentication
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.ServerAuthenticationEntryPoint import org.springframework.security.web.server.ServerAuthenticationEntryPoint
import org.springframework.security.web.server.context.ServerSecurityContextRepository import org.springframework.security.web.server.context.ServerSecurityContextRepository
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository
import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.reactive.config.EnableWebFlux
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import java.util.*
/** /**
* Tests for [ServerHttpBasicDsl] * Tests for [ServerHttpBasicDsl]
@ -105,25 +107,26 @@ class ServerHttpBasicDslTests {
@Test @Test
fun `http basic when custom authentication manager then manager used`() { fun `http basic when custom authentication manager then manager used`() {
given<Mono<Authentication>>(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()))
.willReturn(Mono.just<Authentication>(TestingAuthenticationToken("user", "password", "ROLE_USER")))
this.spring.register(CustomAuthenticationManagerConfig::class.java).autowire() this.spring.register(CustomAuthenticationManagerConfig::class.java).autowire()
mockkObject(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER)
every {
CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any())
} returns Mono.just<Authentication>(TestingAuthenticationToken("user", "password", "ROLE_USER"))
this.client.get() this.client.get()
.uri("/") .uri("/")
.header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:password".toByteArray())) .header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:password".toByteArray()))
.exchange() .exchange()
verify<ReactiveAuthenticationManager>(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER) verify(exactly = 1) { CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
.authenticate(any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomAuthenticationManagerConfig { open class CustomAuthenticationManagerConfig {
companion object { companion object {
var AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = mock(ReactiveAuthenticationManager::class.java) val AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = ReactiveAuthenticationManager { Mono.empty() }
} }
@Bean @Bean
@ -142,21 +145,25 @@ class ServerHttpBasicDslTests {
@Test @Test
fun `http basic when custom security context repository then repository used`() { fun `http basic when custom security context repository then repository used`() {
this.spring.register(CustomSecurityContextRepositoryConfig::class.java, UserDetailsConfig::class.java).autowire() this.spring.register(CustomSecurityContextRepositoryConfig::class.java, UserDetailsConfig::class.java).autowire()
mockkObject(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY)
every {
CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.save(any(), any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/") .uri("/")
.header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:password".toByteArray())) .header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:password".toByteArray()))
.exchange() .exchange()
verify<ServerSecurityContextRepository>(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY) verify(exactly = 1) { CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.save(any(), any()) }
.save(any(), any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomSecurityContextRepositoryConfig { open class CustomSecurityContextRepositoryConfig {
companion object { companion object {
var SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = mock(ServerSecurityContextRepository::class.java) val SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = WebSessionServerSecurityContextRepository()
} }
@Bean @Bean
@ -175,20 +182,24 @@ class ServerHttpBasicDslTests {
@Test @Test
fun `http basic when custom authentication entry point then entry point used`() { fun `http basic when custom authentication entry point then entry point used`() {
this.spring.register(CustomAuthenticationEntryPointConfig::class.java, UserDetailsConfig::class.java).autowire() this.spring.register(CustomAuthenticationEntryPointConfig::class.java, UserDetailsConfig::class.java).autowire()
mockkObject(CustomAuthenticationEntryPointConfig.ENTRY_POINT)
every {
CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/") .uri("/")
.exchange() .exchange()
verify<ServerAuthenticationEntryPoint>(CustomAuthenticationEntryPointConfig.ENTRY_POINT) verify(exactly = 1) { CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any()) }
.commence(any(), any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomAuthenticationEntryPointConfig { open class CustomAuthenticationEntryPointConfig {
companion object { companion object {
var ENTRY_POINT: ServerAuthenticationEntryPoint = mock(ServerAuthenticationEntryPoint::class.java) val ENTRY_POINT: ServerAuthenticationEntryPoint = ServerAuthenticationEntryPoint { _, _ -> Mono.empty() }
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,19 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import java.math.BigInteger
import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
import javax.annotation.PreDestroy
import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer import okhttp3.mockwebserver.MockWebServer
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -40,11 +47,6 @@ import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.reactive.config.EnableWebFlux
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import java.math.BigInteger
import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
import javax.annotation.PreDestroy
/** /**
* Tests for [ServerJwtDsl] * Tests for [ServerJwtDsl]
@ -125,20 +127,25 @@ class ServerJwtDslTests {
@Test @Test
fun `jwt when using custom JWT decoded then custom decoded used`() { fun `jwt when using custom JWT decoded then custom decoded used`() {
this.spring.register(CustomDecoderConfig::class.java).autowire() this.spring.register(CustomDecoderConfig::class.java).autowire()
mockkObject(CustomDecoderConfig.JWT_DECODER)
every {
CustomDecoderConfig.JWT_DECODER.decode("token")
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/") .uri("/")
.headers { headers: HttpHeaders -> headers.setBearerAuth("token") } .headers { headers: HttpHeaders -> headers.setBearerAuth("token") }
.exchange() .exchange()
verify(CustomDecoderConfig.JWT_DECODER).decode("token") verify(exactly = 1) { CustomDecoderConfig.JWT_DECODER.decode("token") }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomDecoderConfig { open class CustomDecoderConfig {
companion object { companion object {
var JWT_DECODER: ReactiveJwtDecoder = mock(ReactiveJwtDecoder::class.java) val JWT_DECODER: ReactiveJwtDecoder = ReactiveJwtDecoder { Mono.empty() }
} }
@Bean @Bean
@ -174,6 +181,7 @@ class ServerJwtDslTests {
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomJwkSetUriConfig { open class CustomJwkSetUriConfig {
companion object { companion object {
var MOCK_WEB_SERVER: MockWebServer = MockWebServer() var MOCK_WEB_SERVER: MockWebServer = MockWebServer()
} }
@ -207,28 +215,33 @@ class ServerJwtDslTests {
@Test @Test
fun `opaque token when custom JWT authentication converter then converter used`() { fun `opaque token when custom JWT authentication converter then converter used`() {
this.spring.register(CustomJwtAuthenticationConverterConfig::class.java).autowire() this.spring.register(CustomJwtAuthenticationConverterConfig::class.java).autowire()
`when`(CustomJwtAuthenticationConverterConfig.DECODER.decode(anyString())).thenReturn( mockkObject(CustomJwtAuthenticationConverterConfig.CONVERTER)
Mono.just(Jwt.withTokenValue("token") mockkObject(CustomJwtAuthenticationConverterConfig.DECODER)
.header("alg", "none") every {
.claim(IdTokenClaimNames.SUB, "user") CustomJwtAuthenticationConverterConfig.DECODER.decode(any())
.build())) } returns Mono.just(Jwt.withTokenValue("token")
`when`(CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any())) .header("alg", "none")
.thenReturn(Mono.just(TestingAuthenticationToken("test", "this", "ROLE"))) .claim(IdTokenClaimNames.SUB, "user")
.build())
every {
CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any())
} returns Mono.just(TestingAuthenticationToken("test", "this", "ROLE"))
this.client.get() this.client.get()
.uri("/") .uri("/")
.headers { headers: HttpHeaders -> headers.setBearerAuth("token") } .headers { headers: HttpHeaders -> headers.setBearerAuth("token") }
.exchange() .exchange()
verify(CustomJwtAuthenticationConverterConfig.CONVERTER).convert(any()) verify(exactly = 1) { CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomJwtAuthenticationConverterConfig { open class CustomJwtAuthenticationConverterConfig {
companion object { companion object {
var CONVERTER: Converter<Jwt, out Mono<AbstractAuthenticationToken>> = mock(Converter::class.java) as Converter<Jwt, out Mono<AbstractAuthenticationToken>> val CONVERTER: Converter<Jwt, out Mono<AbstractAuthenticationToken>> = Converter { Mono.empty() }
var DECODER: ReactiveJwtDecoder = mock(ReactiveJwtDecoder::class.java) val DECODER: ReactiveJwtDecoder = ReactiveJwtDecoder { Mono.empty() }
} }
@Bean @Bean
@ -246,9 +259,7 @@ class ServerJwtDslTests {
} }
@Bean @Bean
open fun jwtDecoder(): ReactiveJwtDecoder { open fun jwtDecoder(): ReactiveJwtDecoder = DECODER
return DECODER
}
} }
@RestController @RestController

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,11 +16,12 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -152,9 +153,8 @@ class ServerLogoutDslTests {
@Test @Test
fun `logout when custom logout handler then custom handler invoked`() { fun `logout when custom logout handler then custom handler invoked`() {
this.spring.register(CustomLogoutHandlerConfig::class.java).autowire() this.spring.register(CustomLogoutHandlerConfig::class.java).autowire()
mockkObject(CustomLogoutHandlerConfig.LOGOUT_HANDLER)
`when`(CustomLogoutHandlerConfig.LOGOUT_HANDLER.logout(any(), any())) every { CustomLogoutHandlerConfig.LOGOUT_HANDLER.logout(any(), any()) } returns Mono.empty()
.thenReturn(Mono.empty())
this.client this.client
.mutateWith(csrf()) .mutateWith(csrf())
@ -162,15 +162,15 @@ class ServerLogoutDslTests {
.uri("/logout") .uri("/logout")
.exchange() .exchange()
verify<ServerLogoutHandler>(CustomLogoutHandlerConfig.LOGOUT_HANDLER) verify(exactly = 1) { CustomLogoutHandlerConfig.LOGOUT_HANDLER.logout(any(), any()) }
.logout(any(), any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomLogoutHandlerConfig { open class CustomLogoutHandlerConfig {
companion object { companion object {
var LOGOUT_HANDLER: ServerLogoutHandler = mock(ServerLogoutHandler::class.java) val LOGOUT_HANDLER: ServerLogoutHandler = ServerLogoutHandler { _, _ -> Mono.empty() }
} }
@Bean @Bean
@ -186,6 +186,10 @@ class ServerLogoutDslTests {
@Test @Test
fun `logout when custom logout success handler then custom handler invoked`() { fun `logout when custom logout success handler then custom handler invoked`() {
this.spring.register(CustomLogoutSuccessHandlerConfig::class.java).autowire() this.spring.register(CustomLogoutSuccessHandlerConfig::class.java).autowire()
mockkObject(CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER)
every {
CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER.onLogoutSuccess(any(), any())
} returns Mono.empty()
this.client this.client
.mutateWith(csrf()) .mutateWith(csrf())
@ -193,15 +197,15 @@ class ServerLogoutDslTests {
.uri("/logout") .uri("/logout")
.exchange() .exchange()
verify<ServerLogoutSuccessHandler>(CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER) verify(exactly = 1) { CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER.onLogoutSuccess(any(), any()) }
.onLogoutSuccess(any(), any())
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class CustomLogoutSuccessHandlerConfig { open class CustomLogoutSuccessHandlerConfig {
companion object { companion object {
var LOGOUT_HANDLER: ServerLogoutSuccessHandler = mock(ServerLogoutSuccessHandler::class.java) val LOGOUT_HANDLER: ServerLogoutSuccessHandler = ServerLogoutSuccessHandler { _, _ -> Mono.empty() }
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,10 +16,11 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -32,6 +33,7 @@ import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames
import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.SecurityWebFilterChain
@ -88,6 +90,10 @@ class ServerOAuth2ClientDslTests {
@Test @Test
fun `OAuth2 client when authorization request repository configured then custom repository used`() { fun `OAuth2 client when authorization request repository configured then custom repository used`() {
this.spring.register(AuthorizationRequestRepositoryConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthorizationRequestRepositoryConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY)
every {
AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri { .uri {
@ -98,15 +104,17 @@ class ServerOAuth2ClientDslTests {
} }
.exchange() .exchange()
verify(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY).loadAuthorizationRequest(any()) verify(exactly = 1) {
AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
}
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class AuthorizationRequestRepositoryConfig { open class AuthorizationRequestRepositoryConfig {
companion object { companion object {
var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java) val AUTHORIZATION_REQUEST_REPOSITORY : ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest>
} }
@Bean @Bean
@ -122,13 +130,18 @@ class ServerOAuth2ClientDslTests {
@Test @Test
fun `OAuth2 client when authentication converter configured then custom converter used`() { fun `OAuth2 client when authentication converter configured then custom converter used`() {
this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthenticationConverterConfig.AUTHORIZATION_REQUEST_REPOSITORY)
`when`(AuthenticationConverterConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())) mockkObject(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER)
.thenReturn(Mono.just(OAuth2AuthorizationRequest.authorizationCode() every {
.authorizationUri("https://example.com/login/oauth/authorize") AuthenticationConverterConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
.clientId("clientId") } returns Mono.just(OAuth2AuthorizationRequest.authorizationCode()
.redirectUri("/authorize/oauth2/code/google") .authorizationUri("https://example.com/login/oauth/authorize")
.build())) .clientId("clientId")
.redirectUri("/authorize/oauth2/code/google")
.build())
every {
AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri { .uri {
@ -139,16 +152,16 @@ class ServerOAuth2ClientDslTests {
} }
.exchange() .exchange()
verify(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER).convert(any()) verify(exactly = 1) { AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class AuthenticationConverterConfig { open class AuthenticationConverterConfig {
companion object { companion object {
var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java) val AUTHORIZATION_REQUEST_REPOSITORY: ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> val AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = ServerAuthenticationConverter { Mono.empty() }
var AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = mock(ServerAuthenticationConverter::class.java)
} }
@Bean @Bean
@ -165,15 +178,22 @@ class ServerOAuth2ClientDslTests {
@Test @Test
fun `OAuth2 client when authentication manager configured then custom manager used`() { fun `OAuth2 client when authentication manager configured then custom manager used`() {
this.spring.register(AuthenticationManagerConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthenticationManagerConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthenticationManagerConfig.AUTHORIZATION_REQUEST_REPOSITORY)
`when`(AuthenticationManagerConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())) mockkObject(AuthenticationManagerConfig.AUTHENTICATION_CONVERTER)
.thenReturn(Mono.just(OAuth2AuthorizationRequest.authorizationCode() mockkObject(AuthenticationManagerConfig.AUTHENTICATION_MANAGER)
.authorizationUri("https://example.com/login/oauth/authorize") every {
.clientId("clientId") AuthenticationManagerConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
.redirectUri("/authorize/oauth2/code/google") } returns Mono.just(OAuth2AuthorizationRequest.authorizationCode()
.build())) .authorizationUri("https://example.com/login/oauth/authorize")
`when`(AuthenticationManagerConfig.AUTHENTICATION_CONVERTER.convert(any())) .clientId("clientId")
.thenReturn(Mono.just(TestingAuthenticationToken("a", "b", "c"))) .redirectUri("/authorize/oauth2/code/google")
.build())
every {
AuthenticationManagerConfig.AUTHENTICATION_CONVERTER.convert(any())
} returns Mono.just(TestingAuthenticationToken("a", "b", "c"))
every {
AuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri { .uri {
@ -184,17 +204,17 @@ class ServerOAuth2ClientDslTests {
} }
.exchange() .exchange()
verify(AuthenticationManagerConfig.AUTHENTICATION_MANAGER).authenticate(any()) verify(exactly = 1) { AuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class AuthenticationManagerConfig { open class AuthenticationManagerConfig {
companion object { companion object {
var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java) val AUTHORIZATION_REQUEST_REPOSITORY: ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> val AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = ServerAuthenticationConverter { Mono.empty() }
var AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = mock(ServerAuthenticationConverter::class.java) val AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = ReactiveAuthenticationManager { Mono.empty() }
var AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = mock(ReactiveAuthenticationManager::class.java)
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,9 +16,11 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -29,12 +31,14 @@ import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.security.web.server.authentication.ServerAuthenticationConverter
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher
import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.reactive.config.EnableWebFlux
import reactor.core.publisher.Mono
/** /**
* Tests for [ServerOAuth2LoginDsl] * Tests for [ServerOAuth2LoginDsl]
@ -105,20 +109,23 @@ class ServerOAuth2LoginDslTests {
@Test @Test
fun `OAuth2 login when authorization request repository configured then custom repository used`() { fun `OAuth2 login when authorization request repository configured then custom repository used`() {
this.spring.register(AuthorizationRequestRepositoryConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthorizationRequestRepositoryConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY)
every {
AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.removeAuthorizationRequest(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/login/oauth2/code/google") .uri("/login/oauth2/code/google")
.exchange() .exchange()
verify(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY).removeAuthorizationRequest(any()) verify(exactly = 1) { AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.removeAuthorizationRequest(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class AuthorizationRequestRepositoryConfig { open class AuthorizationRequestRepositoryConfig {
companion object { companion object {
var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java) val AUTHORIZATION_REQUEST_REPOSITORY: ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest>
} }
@Bean @Bean
@ -134,19 +141,24 @@ class ServerOAuth2LoginDslTests {
@Test @Test
fun `OAuth2 login when authentication matcher configured then custom matcher used`() { fun `OAuth2 login when authentication matcher configured then custom matcher used`() {
this.spring.register(AuthenticationMatcherConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthenticationMatcherConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthenticationMatcherConfig.AUTHENTICATION_MATCHER)
every {
AuthenticationMatcherConfig.AUTHENTICATION_MATCHER.matches(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/") .uri("/")
.exchange() .exchange()
verify(AuthenticationMatcherConfig.AUTHENTICATION_MATCHER).matches(any()) verify(exactly = 1) { AuthenticationMatcherConfig.AUTHENTICATION_MATCHER.matches(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class AuthenticationMatcherConfig { open class AuthenticationMatcherConfig {
companion object { companion object {
var AUTHENTICATION_MATCHER: ServerWebExchangeMatcher = mock(ServerWebExchangeMatcher::class.java) val AUTHENTICATION_MATCHER: ServerWebExchangeMatcher = ServerWebExchangeMatcher { Mono.empty() }
} }
@Bean @Bean
@ -162,19 +174,24 @@ class ServerOAuth2LoginDslTests {
@Test @Test
fun `OAuth2 login when authentication converter configured then custom converter used`() { fun `OAuth2 login when authentication converter configured then custom converter used`() {
this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER)
every {
AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/login/oauth2/code/google") .uri("/login/oauth2/code/google")
.exchange() .exchange()
verify(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER).convert(any()) verify(exactly = 1) { AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class AuthenticationConverterConfig { open class AuthenticationConverterConfig {
companion object { companion object {
var AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = mock(ServerAuthenticationConverter::class.java) val AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = ServerAuthenticationConverter { Mono.empty() }
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,16 +16,19 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import java.math.BigInteger
import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.mock
import org.mockito.Mockito.verify
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.http.HttpStatus import org.springframework.http.HttpStatus
import org.springframework.http.server.reactive.ServerHttpRequest
import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
@ -36,10 +39,7 @@ import org.springframework.security.web.server.authorization.HttpStatusServerAcc
import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.reactive.config.EnableWebFlux
import org.springframework.web.server.ServerWebExchange import org.springframework.web.server.ServerWebExchange
import java.math.BigInteger import reactor.core.publisher.Mono
import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
/** /**
* Tests for [ServerOAuth2ResourceServerDsl] * Tests for [ServerOAuth2ResourceServerDsl]
@ -127,20 +127,25 @@ class ServerOAuth2ResourceServerDslTests {
@Test @Test
fun `request when custom bearer token converter configured then custom converter used`() { fun `request when custom bearer token converter configured then custom converter used`() {
this.spring.register(BearerTokenConverterConfig::class.java).autowire() this.spring.register(BearerTokenConverterConfig::class.java).autowire()
mockkObject(BearerTokenConverterConfig.CONVERTER)
every {
BearerTokenConverterConfig.CONVERTER.convert(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/") .uri("/")
.headers { it.setBearerAuth(validJwt) } .headers { it.setBearerAuth(validJwt) }
.exchange() .exchange()
verify(BearerTokenConverterConfig.CONVERTER).convert(any()) verify(exactly = 1) { BearerTokenConverterConfig.CONVERTER.convert(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class BearerTokenConverterConfig { open class BearerTokenConverterConfig {
companion object { companion object {
val CONVERTER: ServerBearerTokenAuthenticationConverter = mock(ServerBearerTokenAuthenticationConverter::class.java) val CONVERTER: ServerBearerTokenAuthenticationConverter = ServerBearerTokenAuthenticationConverter()
} }
@Bean @Bean
@ -162,21 +167,25 @@ class ServerOAuth2ResourceServerDslTests {
@Test @Test
fun `request when custom authentication manager resolver configured then custom resolver used`() { fun `request when custom authentication manager resolver configured then custom resolver used`() {
this.spring.register(AuthenticationManagerResolverConfig::class.java).autowire() this.spring.register(AuthenticationManagerResolverConfig::class.java).autowire()
mockkObject(AuthenticationManagerResolverConfig.RESOLVER)
every {
AuthenticationManagerResolverConfig.RESOLVER.resolve(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/") .uri("/")
.headers { it.setBearerAuth(validJwt) } .headers { it.setBearerAuth(validJwt) }
.exchange() .exchange()
verify(AuthenticationManagerResolverConfig.RESOLVER).resolve(any()) verify(exactly = 1) { AuthenticationManagerResolverConfig.RESOLVER.resolve(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class AuthenticationManagerResolverConfig { open class AuthenticationManagerResolverConfig {
companion object { companion object {
val RESOLVER: ReactiveAuthenticationManagerResolver<ServerWebExchange> = val RESOLVER: ReactiveAuthenticationManagerResolver<ServerWebExchange> = ReactiveAuthenticationManagerResolver { Mono.empty() }
mock(ReactiveAuthenticationManagerResolver::class.java) as ReactiveAuthenticationManagerResolver<ServerWebExchange>
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,22 +16,22 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito.`when`
import org.mockito.Mockito.verify
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.savedrequest.ServerRequestCache import org.springframework.security.web.server.savedrequest.ServerRequestCache
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache
import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.reactive.config.EnableWebFlux
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
@ -59,20 +59,24 @@ class ServerRequestCacheDslTests {
@Test @Test
fun `GET when request cache enabled then redirected to cached page`() { fun `GET when request cache enabled then redirected to cached page`() {
this.spring.register(RequestCacheConfig::class.java, UserDetailsConfig::class.java).autowire() this.spring.register(RequestCacheConfig::class.java, UserDetailsConfig::class.java).autowire()
`when`(RequestCacheConfig.REQUEST_CACHE.removeMatchingRequest(any())).thenReturn(Mono.empty()) mockkObject(RequestCacheConfig.REQUEST_CACHE)
every {
RequestCacheConfig.REQUEST_CACHE.removeMatchingRequest(any())
} returns Mono.empty()
this.client.get() this.client.get()
.uri("/test") .uri("/test")
.exchange() .exchange()
verify(RequestCacheConfig.REQUEST_CACHE).saveRequest(any()) verify(exactly = 1) { RequestCacheConfig.REQUEST_CACHE.saveRequest(any()) }
} }
@EnableWebFluxSecurity @EnableWebFluxSecurity
@EnableWebFlux @EnableWebFlux
open class RequestCacheConfig { open class RequestCacheConfig {
companion object { companion object {
var REQUEST_CACHE: ServerRequestCache = Mockito.mock(ServerRequestCache::class.java) val REQUEST_CACHE: ServerRequestCache = WebSessionServerRequestCache()
} }
@Bean @Bean

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,10 +16,13 @@
package org.springframework.security.config.web.server package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockk
import java.security.cert.Certificate
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.`when`
import org.mockito.Mockito.mock
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.ApplicationContext import org.springframework.context.ApplicationContext
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
@ -30,10 +33,10 @@ import org.springframework.http.server.reactive.ServerHttpRequestDecorator
import org.springframework.http.server.reactive.SslInfo import org.springframework.http.server.reactive.SslInfo
import org.springframework.lang.Nullable import org.springframework.lang.Nullable
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.core.annotation.AuthenticationPrincipal import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor
import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.authentication.ReactivePreAuthenticatedAuthenticationManager import org.springframework.security.web.server.authentication.ReactivePreAuthenticatedAuthenticationManager
@ -50,9 +53,6 @@ import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain import org.springframework.web.server.WebFilterChain
import org.springframework.web.server.adapter.WebHttpHandlerBuilder import org.springframework.web.server.adapter.WebHttpHandlerBuilder
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import java.security.cert.Certificate
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
/** /**
* Tests for [ServerX509Dsl] * Tests for [ServerX509Dsl]
@ -214,9 +214,9 @@ class ServerX509DslTests {
private fun decorate(exchange: ServerWebExchange): ServerWebExchange { private fun decorate(exchange: ServerWebExchange): ServerWebExchange {
val decorated: ServerHttpRequestDecorator = object : ServerHttpRequestDecorator(exchange.request) { val decorated: ServerHttpRequestDecorator = object : ServerHttpRequestDecorator(exchange.request) {
override fun getSslInfo(): SslInfo { override fun getSslInfo(): SslInfo {
val sslInfo = mock(SslInfo::class.java) val sslInfo: SslInfo = mockk()
`when`(sslInfo.sessionId).thenReturn("sessionId") every { sslInfo.sessionId } returns "sessionId"
`when`(sslInfo.peerCertificates).thenReturn(arrayOf(certificate)) every { sslInfo.peerCertificates } returns arrayOf(certificate)
return sslInfo return sslInfo
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,16 +16,17 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.core.Authentication
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.core.userdetails.UserDetailsService import org.springframework.security.core.userdetails.UserDetailsService
import org.springframework.security.provisioning.InMemoryUserDetailsManager import org.springframework.security.provisioning.InMemoryUserDetailsManager
@ -34,14 +35,13 @@ import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequ
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy
import org.springframework.security.web.csrf.CsrfTokenRepository import org.springframework.security.web.csrf.CsrfTokenRepository
import org.springframework.security.web.csrf.DefaultCsrfToken import org.springframework.security.web.csrf.DefaultCsrfToken
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
import org.springframework.security.web.util.matcher.AntPathRequestMatcher import org.springframework.security.web.util.matcher.AntPathRequestMatcher
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import org.springframework.test.web.servlet.post import org.springframework.test.web.servlet.post
import org.springframework.web.bind.annotation.PostMapping import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
/** /**
* Tests for [CsrfDsl] * Tests for [CsrfDsl]
@ -110,20 +110,22 @@ class CsrfDslTests {
@Test @Test
fun `CSRF when custom CSRF token repository then repo used`() { fun `CSRF when custom CSRF token repository then repo used`() {
`when`(CustomRepositoryConfig.REPO.loadToken(any<HttpServletRequest>()))
.thenReturn(DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"))
this.spring.register(CustomRepositoryConfig::class.java).autowire() this.spring.register(CustomRepositoryConfig::class.java).autowire()
mockkObject(CustomRepositoryConfig.REPO)
every {
CustomRepositoryConfig.REPO.loadToken(any())
} returns DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")
this.mockMvc.get("/test1") this.mockMvc.get("/test1")
verify(CustomRepositoryConfig.REPO).loadToken(any<HttpServletRequest>()) verify(exactly = 1) { CustomRepositoryConfig.REPO.loadToken(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class CustomRepositoryConfig : WebSecurityConfigurerAdapter() { open class CustomRepositoryConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var REPO: CsrfTokenRepository = mock(CsrfTokenRepository::class.java) val REPO: CsrfTokenRepository = HttpSessionCsrfTokenRepository()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -164,18 +166,20 @@ class CsrfDslTests {
@Test @Test
fun `CSRF when custom session authentication strategy then strategy used`() { fun `CSRF when custom session authentication strategy then strategy used`() {
this.spring.register(CustomStrategyConfig::class.java).autowire() this.spring.register(CustomStrategyConfig::class.java).autowire()
mockkObject(CustomStrategyConfig.STRATEGY)
every { CustomStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) } returns Unit
this.mockMvc.perform(formLogin()) this.mockMvc.perform(formLogin())
verify(CustomStrategyConfig.STRATEGY, atLeastOnce()) verify(exactly = 1) { CustomStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) }
.onAuthentication(any(Authentication::class.java), any(HttpServletRequest::class.java), any(HttpServletResponse::class.java))
} }
@EnableWebSecurity @EnableWebSecurity
open class CustomStrategyConfig : WebSecurityConfigurerAdapter() { open class CustomStrategyConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var STRATEGY: SessionAuthenticationStrategy = mock(SessionAuthenticationStrategy::class.java) val STRATEGY: SessionAuthenticationStrategy = SessionAuthenticationStrategy { _, _, _ -> }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,11 +16,12 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import javax.servlet.http.HttpServletRequest
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.mock
import org.mockito.Mockito.verify
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
@ -29,7 +30,6 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.core.AuthenticationException
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.core.userdetails.UserDetailsService import org.springframework.security.core.userdetails.UserDetailsService
import org.springframework.security.provisioning.InMemoryUserDetailsManager import org.springframework.security.provisioning.InMemoryUserDetailsManager
@ -39,8 +39,6 @@ import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import org.springframework.web.bind.annotation.GetMapping import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
/** /**
* Tests for [HttpBasicDsl] * Tests for [HttpBasicDsl]
@ -125,19 +123,19 @@ class HttpBasicDslTests {
@Test @Test
fun `http basic when custom authentication entry point then used`() { fun `http basic when custom authentication entry point then used`() {
this.spring.register(CustomAuthenticationEntryPointConfig::class.java).autowire() this.spring.register(CustomAuthenticationEntryPointConfig::class.java).autowire()
mockkObject(CustomAuthenticationEntryPointConfig.ENTRY_POINT)
every { CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) } returns Unit
this.mockMvc.get("/") this.mockMvc.get("/")
verify<AuthenticationEntryPoint>(CustomAuthenticationEntryPointConfig.ENTRY_POINT) verify(exactly = 1) { CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) }
.commence(any(HttpServletRequest::class.java),
any(HttpServletResponse::class.java),
any(AuthenticationException::class.java))
} }
@EnableWebSecurity @EnableWebSecurity
open class CustomAuthenticationEntryPointConfig : WebSecurityConfigurerAdapter() { open class CustomAuthenticationEntryPointConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var ENTRY_POINT: AuthenticationEntryPoint = mock(AuthenticationEntryPoint::class.java) val ENTRY_POINT: AuthenticationEntryPoint = AuthenticationEntryPoint { _, _, _ -> }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -154,21 +152,27 @@ class HttpBasicDslTests {
@Test @Test
fun `http basic when custom authentication details source then used`() { fun `http basic when custom authentication details source then used`() {
this.spring.register(CustomAuthenticationDetailsSourceConfig::class.java, this.spring
UserConfig::class.java, MainController::class.java).autowire() .register(CustomAuthenticationDetailsSourceConfig::class.java, UserConfig::class.java, MainController::class.java)
.autowire()
mockkObject(CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE)
every {
CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE.buildDetails(any())
} returns Any()
this.mockMvc.get("/") { this.mockMvc.get("/") {
with(httpBasic("username", "password")) with(httpBasic("username", "password"))
} }
verify(CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE) verify(exactly = 1) { CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE.buildDetails(any()) }
.buildDetails(any(HttpServletRequest::class.java))
} }
@EnableWebSecurity @EnableWebSecurity
open class CustomAuthenticationDetailsSourceConfig : WebSecurityConfigurerAdapter() { open class CustomAuthenticationDetailsSourceConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var AUTHENTICATION_DETAILS_SOURCE = mock(AuthenticationDetailsSource::class.java) as AuthenticationDetailsSource<HttpServletRequest, *> val AUTHENTICATION_DETAILS_SOURCE: AuthenticationDetailsSource<HttpServletRequest, *> =
AuthenticationDetailsSource<HttpServletRequest, Any> { Any() }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,12 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.mock
import org.mockito.Mockito.verify
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.mock.web.MockHttpSession import org.springframework.mock.web.MockHttpSession
import org.springframework.security.authentication.TestingAuthenticationToken import org.springframework.security.authentication.TestingAuthenticationToken
@ -285,18 +285,21 @@ class LogoutDslTests {
@Test @Test
fun `logout when custom logout handler then custom handler used`() { fun `logout when custom logout handler then custom handler used`() {
this.spring.register(CustomLogoutHandlerConfig::class.java).autowire() this.spring.register(CustomLogoutHandlerConfig::class.java).autowire()
mockkObject(CustomLogoutHandlerConfig.HANDLER)
every { CustomLogoutHandlerConfig.HANDLER.logout(any(), any(), any()) } returns Unit
this.mockMvc.post("/logout") { this.mockMvc.post("/logout") {
with(csrf()) with(csrf())
} }
verify(CustomLogoutHandlerConfig.HANDLER).logout(any(), any(), any()) verify(exactly = 1) { CustomLogoutHandlerConfig.HANDLER.logout(any(), any(), any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class CustomLogoutHandlerConfig : WebSecurityConfigurerAdapter() { open class CustomLogoutHandlerConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var HANDLER: LogoutHandler = mock(LogoutHandler::class.java) val HANDLER: LogoutHandler = LogoutHandler { _, _, _ -> }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,10 +16,11 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
@ -33,6 +34,8 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository
import org.springframework.security.oauth2.core.OAuth2AccessToken import org.springframework.security.oauth2.core.OAuth2AccessToken
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
@ -77,6 +80,9 @@ class OAuth2ClientDslTests {
@Test @Test
fun `oauth2Client when custom authorized client repository then repository used`() { fun `oauth2Client when custom authorized client repository then repository used`() {
this.spring.register(ClientRepositoryConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(ClientRepositoryConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(ClientRepositoryConfig.REQUEST_REPOSITORY)
mockkObject(ClientRepositoryConfig.CLIENT)
mockkObject(ClientRepositoryConfig.CLIENT_REPOSITORY)
val authorizationRequest = OAuth2AuthorizationRequest val authorizationRequest = OAuth2AuthorizationRequest
.authorizationCode() .authorizationCode()
.state("test") .state("test")
@ -85,30 +91,41 @@ class OAuth2ClientDslTests {
.redirectUri("http://localhost/callback") .redirectUri("http://localhost/callback")
.attributes(mapOf(Pair(OAuth2ParameterNames.REGISTRATION_ID, "registrationId"))) .attributes(mapOf(Pair(OAuth2ParameterNames.REGISTRATION_ID, "registrationId")))
.build() .build()
`when`(ClientRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())) every {
.thenReturn(authorizationRequest) ClientRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())
`when`(ClientRepositoryConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())) } returns authorizationRequest
.thenReturn(authorizationRequest) every {
`when`(ClientRepositoryConfig.CLIENT.getTokenResponse(any())) ClientRepositoryConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())
.thenReturn(OAuth2AccessTokenResponse } returns authorizationRequest
.withToken("token") every {
.tokenType(OAuth2AccessToken.TokenType.BEARER) ClientRepositoryConfig.CLIENT.getTokenResponse(any())
.build()) } returns OAuth2AccessTokenResponse
.withToken("token")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.build()
every {
ClientRepositoryConfig.CLIENT_REPOSITORY.saveAuthorizedClient(any(), any(), any(), any())
} returns Unit
this.mockMvc.get("/callback") { this.mockMvc.get("/callback") {
param("state", "test") param("state", "test")
param("code", "123") param("code", "123")
} }
verify(ClientRepositoryConfig.CLIENT_REPOSITORY).saveAuthorizedClient(any(), any(), any(), any()) verify(exactly = 1) { ClientRepositoryConfig.CLIENT_REPOSITORY.saveAuthorizedClient(any(), any(), any(), any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class ClientRepositoryConfig : WebSecurityConfigurerAdapter() { open class ClientRepositoryConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> val REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> HttpSessionOAuth2AuthorizationRequestRepository()
var CLIENT_REPOSITORY: OAuth2AuthorizedClientRepository = mock(OAuth2AuthorizedClientRepository::class.java) val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
OAuth2AccessTokenResponseClient {
OAuth2AccessTokenResponse.withToken("some tokenValue").build()
}
val CLIENT_REPOSITORY: OAuth2AuthorizedClientRepository = HttpSessionOAuth2AuthorizedClientRepository()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,15 +16,20 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkObject
import io.mockk.verify
import javax.servlet.http.HttpServletRequest
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.BeanCreationException import org.springframework.beans.factory.BeanCreationException
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.security.authentication.AuthenticationManager import org.springframework.security.authentication.AuthenticationManager
import org.springframework.security.authentication.AuthenticationManagerResolver import org.springframework.security.authentication.AuthenticationManagerResolver
import org.springframework.security.authentication.TestingAuthenticationToken
import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
@ -34,11 +39,11 @@ import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.security.oauth2.jwt.JwtDecoder import org.springframework.security.oauth2.jwt.JwtDecoder
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken
import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver
import org.springframework.security.oauth2.server.resource.web.DefaultBearerTokenResolver
import org.springframework.security.web.AuthenticationEntryPoint import org.springframework.security.web.AuthenticationEntryPoint
import org.springframework.security.web.access.AccessDeniedHandler import org.springframework.security.web.access.AccessDeniedHandler
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import javax.servlet.http.HttpServletRequest
/** /**
* Tests for [OAuth2ResourceServerDsl] * Tests for [OAuth2ResourceServerDsl]
@ -61,16 +66,19 @@ class OAuth2ResourceServerDslTests {
@Test @Test
fun `oauth2Resource server when custom entry point then entry point used`() { fun `oauth2Resource server when custom entry point then entry point used`() {
this.spring.register(EntryPointConfig::class.java).autowire() this.spring.register(EntryPointConfig::class.java).autowire()
mockkObject(EntryPointConfig.ENTRY_POINT)
every { EntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) } returns Unit
this.mockMvc.get("/") this.mockMvc.get("/")
verify(EntryPointConfig.ENTRY_POINT).commence(any(), any(), any()) verify(exactly = 1) { EntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class EntryPointConfig : WebSecurityConfigurerAdapter() { open class EntryPointConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var ENTRY_POINT: AuthenticationEntryPoint = mock(AuthenticationEntryPoint::class.java) val ENTRY_POINT: AuthenticationEntryPoint = AuthenticationEntryPoint { _, _, _ -> }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -86,24 +94,33 @@ class OAuth2ResourceServerDslTests {
} }
@Bean @Bean
open fun jwtDecoder(): JwtDecoder { open fun jwtDecoder(): JwtDecoder = mockk()
return mock(JwtDecoder::class.java)
}
} }
@Test @Test
fun `oauth2Resource server when custom bearer token resolver then resolver used`() { fun `oauth2Resource server when custom bearer token resolver then resolver used`() {
this.spring.register(BearerTokenResolverConfig::class.java).autowire() this.spring.register(BearerTokenResolverConfig::class.java).autowire()
mockkObject(BearerTokenResolverConfig.RESOLVER)
mockkObject(BearerTokenResolverConfig.DECODER)
every { BearerTokenResolverConfig.RESOLVER.resolve(any()) } returns "anything"
every { BearerTokenResolverConfig.DECODER.decode(any()) } returns JWT
this.mockMvc.get("/") this.mockMvc.get("/")
verify(BearerTokenResolverConfig.RESOLVER).resolve(any()) verify(exactly = 1) { BearerTokenResolverConfig.RESOLVER.resolve(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class BearerTokenResolverConfig : WebSecurityConfigurerAdapter() { open class BearerTokenResolverConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var RESOLVER: BearerTokenResolver = mock(BearerTokenResolver::class.java) val RESOLVER: BearerTokenResolver = DefaultBearerTokenResolver()
val DECODER: JwtDecoder = JwtDecoder {
Jwt.withTokenValue("token")
.header("alg", "none")
.claim(SUB, "user")
.build()
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -119,28 +136,39 @@ class OAuth2ResourceServerDslTests {
} }
@Bean @Bean
open fun jwtDecoder(): JwtDecoder { open fun jwtDecoder(): JwtDecoder = DECODER
return mock(JwtDecoder::class.java)
}
} }
@Test @Test
fun `oauth2Resource server when custom access denied handler then handler used`() { fun `oauth2Resource server when custom access denied handler then handler used`() {
this.spring.register(AccessDeniedHandlerConfig::class.java).autowire() this.spring.register(AccessDeniedHandlerConfig::class.java).autowire()
`when`(AccessDeniedHandlerConfig.DECODER.decode(anyString())).thenReturn(JWT) mockkObject(AccessDeniedHandlerConfig.DENIED_HANDLER)
mockkObject(AccessDeniedHandlerConfig.DECODER)
every {
AccessDeniedHandlerConfig.DECODER.decode(any())
} returns JWT
every {
AccessDeniedHandlerConfig.DENIED_HANDLER.handle(any(), any(), any())
} returns Unit
this.mockMvc.get("/") { this.mockMvc.get("/") {
header("Authorization", "Bearer token") header("Authorization", "Bearer token")
} }
verify(AccessDeniedHandlerConfig.DENIED_HANDLER).handle(any(), any(), any()) verify(exactly = 1) { AccessDeniedHandlerConfig.DENIED_HANDLER.handle(any(), any(), any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class AccessDeniedHandlerConfig : WebSecurityConfigurerAdapter() { open class AccessDeniedHandlerConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var DENIED_HANDLER: AccessDeniedHandler = mock(AccessDeniedHandler::class.java) val DECODER: JwtDecoder = JwtDecoder { _ ->
var DECODER: JwtDecoder = mock(JwtDecoder::class.java) Jwt.withTokenValue("token")
.header("alg", "none")
.claim(SUB, "user")
.build()
}
val DENIED_HANDLER: AccessDeniedHandler = AccessDeniedHandler { _, _, _ -> }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -156,31 +184,36 @@ class OAuth2ResourceServerDslTests {
} }
@Bean @Bean
open fun jwtDecoder(): JwtDecoder { open fun jwtDecoder(): JwtDecoder = DECODER
return DECODER
}
} }
@Test @Test
fun `oauth2Resource server when custom authentication manager resolver then resolver used`() { fun `oauth2Resource server when custom authentication manager resolver then resolver used`() {
this.spring.register(AuthenticationManagerResolverConfig::class.java).autowire() this.spring.register(AuthenticationManagerResolverConfig::class.java).autowire()
`when`(AuthenticationManagerResolverConfig.RESOLVER.resolve(any())).thenReturn( mockkObject(AuthenticationManagerResolverConfig.RESOLVER)
AuthenticationManager { every {
JwtAuthenticationToken(JWT) AuthenticationManagerResolverConfig.RESOLVER.resolve(any())
} } returns AuthenticationManager {
) JwtAuthenticationToken(JWT)
}
this.mockMvc.get("/") { this.mockMvc.get("/") {
header("Authorization", "Bearer token") header("Authorization", "Bearer token")
} }
verify(AuthenticationManagerResolverConfig.RESOLVER).resolve(any()) verify(exactly = 1) { AuthenticationManagerResolverConfig.RESOLVER.resolve(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class AuthenticationManagerResolverConfig : WebSecurityConfigurerAdapter() { open class AuthenticationManagerResolverConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var RESOLVER: AuthenticationManagerResolver<*> = mock(AuthenticationManagerResolver::class.java) val RESOLVER: AuthenticationManagerResolver<HttpServletRequest> =
AuthenticationManagerResolver {
AuthenticationManager {
TestingAuthenticationToken("a,", "b", "c")
}
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -189,7 +222,7 @@ class OAuth2ResourceServerDslTests {
authorize(anyRequest, authenticated) authorize(anyRequest, authenticated)
} }
oauth2ResourceServer { oauth2ResourceServer {
authenticationManagerResolver = RESOLVER as AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver = RESOLVER
} }
} }
} }
@ -210,8 +243,7 @@ class OAuth2ResourceServerDslTests {
authorize(anyRequest, authenticated) authorize(anyRequest, authenticated)
} }
oauth2ResourceServer { oauth2ResourceServer {
authenticationManagerResolver = mock(AuthenticationManagerResolver::class.java) authenticationManagerResolver = mockk()
as AuthenticationManagerResolver<HttpServletRequest>
opaqueToken { } opaqueToken { }
} }
} }

View File

@ -16,14 +16,22 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.Called
import io.mockk.confirmVerified
import io.mockk.every
import io.mockk.justRun
import io.mockk.mockk
import io.mockk.mockkObject
import io.mockk.verify
import javax.servlet.http.HttpServletRequest
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.jupiter.api.fail import org.junit.jupiter.api.fail
import org.mockito.BDDMockito.given
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean
import org.springframework.core.annotation.Order import org.springframework.core.annotation.Order
import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.mock.web.MockHttpSession import org.springframework.mock.web.MockHttpSession
import org.springframework.security.authentication.RememberMeAuthenticationToken import org.springframework.security.authentication.RememberMeAuthenticationToken
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder
@ -36,21 +44,21 @@ import org.springframework.security.core.authority.AuthorityUtils
import org.springframework.security.core.userdetails.PasswordEncodedUser import org.springframework.security.core.userdetails.PasswordEncodedUser
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.core.userdetails.UserDetailsService import org.springframework.security.core.userdetails.UserDetailsService
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder
import org.springframework.security.crypto.password.PasswordEncoder
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf
import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers
import org.springframework.security.web.authentication.AuthenticationSuccessHandler import org.springframework.security.web.authentication.AuthenticationSuccessHandler
import org.springframework.security.web.authentication.NullRememberMeServices
import org.springframework.security.web.authentication.RememberMeServices import org.springframework.security.web.authentication.RememberMeServices
import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices
import org.springframework.security.web.authentication.rememberme.PersistentRememberMeToken
import org.springframework.security.web.authentication.rememberme.PersistentTokenRepository import org.springframework.security.web.authentication.rememberme.PersistentTokenRepository
import org.springframework.security.web.util.matcher.AntPathRequestMatcher import org.springframework.security.web.util.matcher.AntPathRequestMatcher
import org.springframework.test.web.servlet.MockHttpServletRequestDsl import org.springframework.test.web.servlet.MockHttpServletRequestDsl
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import org.springframework.test.web.servlet.post import org.springframework.test.web.servlet.post
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
/** /**
* Tests for [RememberMeDsl] * Tests for [RememberMeDsl]
@ -58,6 +66,7 @@ import javax.servlet.http.HttpServletResponse
* @author Ivan Pavlov * @author Ivan Pavlov
*/ */
internal class RememberMeDslTests { internal class RememberMeDslTests {
@Rule @Rule
@JvmField @JvmField
val spring = SpringTestRule() val spring = SpringTestRule()
@ -65,6 +74,8 @@ internal class RememberMeDslTests {
@Autowired @Autowired
lateinit var mockMvc: MockMvc lateinit var mockMvc: MockMvc
private val mockAuthentication: Authentication = mockk()
@Test @Test
fun `Remember Me login when remember me true then responds with remember me cookie`() { fun `Remember Me login when remember me true then responds with remember me cookie`() {
this.spring.register(RememberMeConfig::class.java).autowire() this.spring.register(RememberMeConfig::class.java).autowire()
@ -165,39 +176,49 @@ internal class RememberMeDslTests {
@Test @Test
fun `Remember Me when remember me services then uses`() { fun `Remember Me when remember me services then uses`() {
RememberMeServicesRefConfig.REMEMBER_ME_SERVICES = mock(RememberMeServices::class.java)
this.spring.register(RememberMeServicesRefConfig::class.java).autowire() this.spring.register(RememberMeServicesRefConfig::class.java).autowire()
mockkObject(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES)
every {
RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.autoLogin(any(),any())
} returns mockAuthentication
every {
RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginFail(any(), any())
} returns Unit
every {
RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginSuccess(any(), any(), any())
} returns Unit
mockMvc.get("/") mockMvc.get("/")
verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).autoLogin(any(HttpServletRequest::class.java),
any(HttpServletResponse::class.java)) verify(exactly = 1) { RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.autoLogin(any(),any()) }
mockMvc.post("/login") { mockMvc.post("/login") {
with(csrf()) with(csrf())
} }
verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).loginFail(any(HttpServletRequest::class.java), verify(exactly = 2) { RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginFail(any(), any()) }
any(HttpServletResponse::class.java))
mockMvc.post("/login") { mockMvc.post("/login") {
loginRememberMeRequest() loginRememberMeRequest()
} }
verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).loginSuccess(any(HttpServletRequest::class.java), verify(exactly = 1) { RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginSuccess(any(), any(), any()) }
any(HttpServletResponse::class.java), any(Authentication::class.java))
} }
@Test @Test
fun `Remember Me when authentication success handler then uses`() { fun `Remember Me when authentication success handler then uses`() {
RememberMeSuccessHandlerConfig.SUCCESS_HANDLER = mock(AuthenticationSuccessHandler::class.java)
this.spring.register(RememberMeSuccessHandlerConfig::class.java).autowire() this.spring.register(RememberMeSuccessHandlerConfig::class.java).autowire()
mockkObject(RememberMeSuccessHandlerConfig.SUCCESS_HANDLER)
justRun {
RememberMeSuccessHandlerConfig.SUCCESS_HANDLER.onAuthenticationSuccess(any(), any(), any())
}
val mvcResult = mockMvc.post("/login") { val mvcResult = mockMvc.post("/login") {
loginRememberMeRequest() loginRememberMeRequest()
}.andReturn() }.andReturn()
verifyNoInteractions(RememberMeSuccessHandlerConfig.SUCCESS_HANDLER)
val rememberMeCookie = mvcResult.response.getCookie("remember-me") val rememberMeCookie = mvcResult.response.getCookie("remember-me")
?: fail { "Missing remember-me cookie in login response" } ?: fail { "Missing remember-me cookie in login response" }
mockMvc.get("/abc") { mockMvc.get("/abc") {
cookie(rememberMeCookie) cookie(rememberMeCookie)
} }
verify(RememberMeSuccessHandlerConfig.SUCCESS_HANDLER).onAuthenticationSuccess(
any(HttpServletRequest::class.java), any(HttpServletResponse::class.java), verify(exactly = 1) { RememberMeSuccessHandlerConfig.SUCCESS_HANDLER.onAuthenticationSuccess(any(), any(), any()) }
any(Authentication::class.java))
} }
@Test @Test
@ -228,13 +249,15 @@ internal class RememberMeDslTests {
@Test @Test
fun `Remember Me when token repository then uses`() { fun `Remember Me when token repository then uses`() {
RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY = mock(PersistentTokenRepository::class.java)
this.spring.register(RememberMeTokenRepositoryConfig::class.java).autowire() this.spring.register(RememberMeTokenRepositoryConfig::class.java).autowire()
mockkObject(RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY)
every {
RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY.createNewToken(any())
} returns Unit
mockMvc.post("/login") { mockMvc.post("/login") {
loginRememberMeRequest() loginRememberMeRequest()
} }
verify(RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY).createNewToken( verify(exactly = 1) { RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY.createNewToken(any()) }
any(PersistentRememberMeToken::class.java))
} }
@Test @Test
@ -312,24 +335,32 @@ internal class RememberMeDslTests {
@Test @Test
fun `Remember Me when global user details service then uses`() { fun `Remember Me when global user details service then uses`() {
RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE = mock(UserDetailsService::class.java)
this.spring.register(RememberMeDefaultUserDetailsServiceConfig::class.java).autowire() this.spring.register(RememberMeDefaultUserDetailsServiceConfig::class.java).autowire()
mockkObject(RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE)
val user = User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER"))
every {
RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user")
} returns user
mockMvc.post("/login") { mockMvc.post("/login") {
loginRememberMeRequest() loginRememberMeRequest()
} }
verify(RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE).loadUserByUsername("user")
verify(exactly = 1) { RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user") }
} }
@Test @Test
fun `Remember Me when user details service then uses`() { fun `Remember Me when user details service then uses`() {
RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE = mock(UserDetailsService::class.java)
this.spring.register(RememberMeUserDetailsServiceConfig::class.java).autowire() this.spring.register(RememberMeUserDetailsServiceConfig::class.java).autowire()
mockkObject(RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE)
val user = User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER")) val user = User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER"))
given(RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user")).willReturn(user) every {
RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user")
} returns user
mockMvc.post("/login") { mockMvc.post("/login") {
loginRememberMeRequest() loginRememberMeRequest()
} }
verify(RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE).loadUserByUsername("user") verify(exactly = 1) { RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user") }
} }
@Test @Test
@ -344,8 +375,10 @@ internal class RememberMeDslTests {
} }
} }
private fun MockHttpServletRequestDsl.loginRememberMeRequest(rememberMeParameter: String = "remember-me", private fun MockHttpServletRequestDsl.loginRememberMeRequest(
rememberMeValue: Boolean? = true) { rememberMeParameter: String = "remember-me",
rememberMeValue: Boolean? = true
) {
with(csrf()) with(csrf())
param("username", "user") param("username", "user")
param("password", "password") param("password", "password")
@ -392,6 +425,11 @@ internal class RememberMeDslTests {
@EnableWebSecurity @EnableWebSecurity
open class RememberMeServicesRefConfig : DefaultUserConfig() { open class RememberMeServicesRefConfig : DefaultUserConfig() {
companion object {
val REMEMBER_ME_SERVICES: RememberMeServices = NullRememberMeServices()
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {
formLogin {} formLogin {}
@ -400,14 +438,15 @@ internal class RememberMeDslTests {
} }
} }
} }
companion object {
lateinit var REMEMBER_ME_SERVICES: RememberMeServices
}
} }
@EnableWebSecurity @EnableWebSecurity
open class RememberMeSuccessHandlerConfig : DefaultUserConfig() { open class RememberMeSuccessHandlerConfig : DefaultUserConfig() {
companion object {
val SUCCESS_HANDLER: AuthenticationSuccessHandler = AuthenticationSuccessHandler { _ , _, _ -> }
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {
formLogin {} formLogin {}
@ -416,10 +455,6 @@ internal class RememberMeDslTests {
} }
} }
} }
companion object {
lateinit var SUCCESS_HANDLER: AuthenticationSuccessHandler
}
} }
@EnableWebSecurity @EnableWebSecurity
@ -453,6 +488,11 @@ internal class RememberMeDslTests {
@EnableWebSecurity @EnableWebSecurity
open class RememberMeTokenRepositoryConfig : DefaultUserConfig() { open class RememberMeTokenRepositoryConfig : DefaultUserConfig() {
companion object {
val TOKEN_REPOSITORY: PersistentTokenRepository = mockk()
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {
formLogin {} formLogin {}
@ -461,10 +501,6 @@ internal class RememberMeDslTests {
} }
} }
} }
companion object {
lateinit var TOKEN_REPOSITORY: PersistentTokenRepository
}
} }
@EnableWebSecurity @EnableWebSecurity
@ -517,6 +553,14 @@ internal class RememberMeDslTests {
@EnableWebSecurity @EnableWebSecurity
open class RememberMeDefaultUserDetailsServiceConfig : DefaultUserConfig() { open class RememberMeDefaultUserDetailsServiceConfig : DefaultUserConfig() {
companion object {
val USER_DETAIL_SERVICE: UserDetailsService = UserDetailsService { _ ->
User("username", "password", emptyList())
}
val PASSWORD_ENCODER: PasswordEncoder = BCryptPasswordEncoder()
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {
formLogin {} formLogin {}
@ -528,13 +572,20 @@ internal class RememberMeDslTests {
auth.userDetailsService(USER_DETAIL_SERVICE) auth.userDetailsService(USER_DETAIL_SERVICE)
} }
companion object { @Bean
lateinit var USER_DETAIL_SERVICE: UserDetailsService open fun delegatingPasswordEncoder(): PasswordEncoder = PASSWORD_ENCODER
}
} }
@EnableWebSecurity @EnableWebSecurity
open class RememberMeUserDetailsServiceConfig : DefaultUserConfig() { open class RememberMeUserDetailsServiceConfig : DefaultUserConfig() {
companion object {
val USER_DETAIL_SERVICE: UserDetailsService = UserDetailsService { _ ->
User("username", "password", emptyList())
}
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {
formLogin {} formLogin {}
@ -543,10 +594,6 @@ internal class RememberMeDslTests {
} }
} }
} }
companion object {
lateinit var USER_DETAIL_SERVICE: UserDetailsService
}
} }
@EnableWebSecurity @EnableWebSecurity

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,14 +16,17 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.security.access.ConfigAttribute
import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.web.FilterInvocation
import org.springframework.security.web.access.channel.ChannelProcessor import org.springframework.security.web.access.channel.ChannelProcessor
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
@ -112,18 +115,22 @@ class RequiresChannelDslTests {
@Test @Test
fun `requires channel when channel processors configured then channel processors used`() { fun `requires channel when channel processors configured then channel processors used`() {
`when`(ChannelProcessorsConfig.CHANNEL_PROCESSOR.supports(any())).thenReturn(true)
this.spring.register(ChannelProcessorsConfig::class.java).autowire() this.spring.register(ChannelProcessorsConfig::class.java).autowire()
mockkObject(ChannelProcessorsConfig.CHANNEL_PROCESSOR)
this.mockMvc.get("/") this.mockMvc.get("/")
verify(ChannelProcessorsConfig.CHANNEL_PROCESSOR).supports(any()) verify(exactly = 0) { ChannelProcessorsConfig.CHANNEL_PROCESSOR.supports(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class ChannelProcessorsConfig : WebSecurityConfigurerAdapter() { open class ChannelProcessorsConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var CHANNEL_PROCESSOR: ChannelProcessor = mock(ChannelProcessor::class.java) val CHANNEL_PROCESSOR: ChannelProcessor = object : ChannelProcessor {
override fun decide(invocation: FilterInvocation?, config: MutableCollection<ConfigAttribute>?) {}
override fun supports(attribute: ConfigAttribute?): Boolean = true
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,10 +16,14 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.every
import io.mockk.justRun
import io.mockk.mockk
import io.mockk.mockkObject
import io.mockk.verify
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.mock.web.MockHttpSession import org.springframework.mock.web.MockHttpSession
@ -38,8 +42,6 @@ import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl import org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
/** /**
* Tests for [SessionManagementDsl] * Tests for [SessionManagementDsl]
@ -59,13 +61,13 @@ class SessionManagementDslTests {
this.spring.register(InvalidSessionUrlConfig::class.java).autowire() this.spring.register(InvalidSessionUrlConfig::class.java).autowire()
this.mockMvc.perform(get("/") this.mockMvc.perform(get("/")
.with { request -> .with { request ->
request.isRequestedSessionIdValid = false request.isRequestedSessionIdValid = false
request.requestedSessionId = "id" request.requestedSessionId = "id"
request request
}) })
.andExpect(status().isFound) .andExpect(status().isFound)
.andExpect(redirectedUrl("/invalid")) .andExpect(redirectedUrl("/invalid"))
} }
@EnableWebSecurity @EnableWebSecurity
@ -84,13 +86,13 @@ class SessionManagementDslTests {
this.spring.register(InvalidSessionStrategyConfig::class.java).autowire() this.spring.register(InvalidSessionStrategyConfig::class.java).autowire()
this.mockMvc.perform(get("/") this.mockMvc.perform(get("/")
.with { request -> .with { request ->
request.isRequestedSessionIdValid = false request.isRequestedSessionIdValid = false
request.requestedSessionId = "id" request.requestedSessionId = "id"
request request
}) })
.andExpect(status().isFound) .andExpect(status().isFound)
.andExpect(redirectedUrl("/invalid")) .andExpect(redirectedUrl("/invalid"))
} }
@EnableWebSecurity @EnableWebSecurity
@ -107,14 +109,16 @@ class SessionManagementDslTests {
@Test @Test
fun `session management when session authentication error url then redirected to url`() { fun `session management when session authentication error url then redirected to url`() {
this.spring.register(SessionAuthenticationErrorUrlConfig::class.java).autowire() this.spring.register(SessionAuthenticationErrorUrlConfig::class.java).autowire()
val session = mock(MockHttpSession::class.java) val authentication: Authentication = mockk()
`when`(session.changeSessionId()).thenThrow(SessionAuthenticationException::class.java) val session: MockHttpSession = mockk(relaxed = true)
every { session.changeSessionId() } throws SessionAuthenticationException("any SessionAuthenticationException")
every<Any?> { session.getAttribute(any()) } returns null
this.mockMvc.perform(get("/") this.mockMvc.perform(get("/")
.with(authentication(mock(Authentication::class.java))) .with(authentication(authentication))
.session(session)) .session(session))
.andExpect(status().isFound) .andExpect(status().isFound)
.andExpect(redirectedUrl("/session-auth-error")) .andExpect(redirectedUrl("/session-auth-error"))
} }
@EnableWebSecurity @EnableWebSecurity
@ -134,14 +138,16 @@ class SessionManagementDslTests {
@Test @Test
fun `session management when session authentication failure handler then handler used`() { fun `session management when session authentication failure handler then handler used`() {
this.spring.register(SessionAuthenticationFailureHandlerConfig::class.java).autowire() this.spring.register(SessionAuthenticationFailureHandlerConfig::class.java).autowire()
val session = mock(MockHttpSession::class.java) val authentication: Authentication = mockk()
`when`(session.changeSessionId()).thenThrow(SessionAuthenticationException::class.java) val session: MockHttpSession = mockk(relaxed = true)
every { session.changeSessionId() } throws SessionAuthenticationException("any SessionAuthenticationException")
every<Any?> { session.getAttribute(any()) } returns null
this.mockMvc.perform(get("/") this.mockMvc.perform(get("/")
.with(authentication(mock(Authentication::class.java))) .with(authentication(authentication))
.session(session)) .session(session))
.andExpect(status().isFound) .andExpect(status().isFound)
.andExpect(redirectedUrl("/session-auth-error")) .andExpect(redirectedUrl("/session-auth-error"))
} }
@EnableWebSecurity @EnableWebSecurity
@ -163,7 +169,7 @@ class SessionManagementDslTests {
this.spring.register(StatelessSessionManagementConfig::class.java).autowire() this.spring.register(StatelessSessionManagementConfig::class.java).autowire()
val result = this.mockMvc.perform(get("/")) val result = this.mockMvc.perform(get("/"))
.andReturn() .andReturn()
assertThat(result.request.getSession(false)).isNull() assertThat(result.request.getSession(false)).isNull()
} }
@ -185,19 +191,26 @@ class SessionManagementDslTests {
@Test @Test
fun `session management when session authentication strategy then strategy used`() { fun `session management when session authentication strategy then strategy used`() {
this.spring.register(SessionAuthenticationStrategyConfig::class.java).autowire() this.spring.register(SessionAuthenticationStrategyConfig::class.java).autowire()
mockkObject(SessionAuthenticationStrategyConfig.STRATEGY)
val authentication: Authentication = mockk(relaxed = true)
val session: MockHttpSession = mockk(relaxed = true)
every { session.changeSessionId() } throws SessionAuthenticationException("any SessionAuthenticationException")
every<Any?> { session.getAttribute(any()) } returns null
justRun { SessionAuthenticationStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) }
this.mockMvc.perform(get("/") this.mockMvc.perform(get("/")
.with(authentication(mock(Authentication::class.java))) .with(authentication(authentication))
.session(mock(MockHttpSession::class.java))) .session(session))
verify(this.spring.getContext().getBean(SessionAuthenticationStrategy::class.java)) verify(exactly = 1) { SessionAuthenticationStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) }
.onAuthentication(any(Authentication::class.java),
any(HttpServletRequest::class.java), any(HttpServletResponse::class.java))
} }
@EnableWebSecurity @EnableWebSecurity
open class SessionAuthenticationStrategyConfig : WebSecurityConfigurerAdapter() { open class SessionAuthenticationStrategyConfig : WebSecurityConfigurerAdapter() {
var mockSessionAuthenticationStrategy: SessionAuthenticationStrategy = mock(SessionAuthenticationStrategy::class.java)
companion object {
val STRATEGY: SessionAuthenticationStrategy = SessionAuthenticationStrategy { _, _, _ -> }
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {
@ -205,14 +218,12 @@ class SessionManagementDslTests {
authorize(anyRequest, authenticated) authorize(anyRequest, authenticated)
} }
sessionManagement { sessionManagement {
sessionAuthenticationStrategy = mockSessionAuthenticationStrategy sessionAuthenticationStrategy = STRATEGY
} }
} }
} }
@Bean @Bean
open fun sessionAuthenticationStrategy(): SessionAuthenticationStrategy { open fun sessionAuthenticationStrategy(): SessionAuthenticationStrategy = STRATEGY
return this.mockSessionAuthenticationStrategy
}
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,9 +16,12 @@
package org.springframework.security.config.web.servlet package org.springframework.security.config.web.servlet
import io.mockk.mockk
import java.security.cert.Certificate
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.mock
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.core.io.ClassPathResource import org.springframework.core.io.ClassPathResource
@ -36,9 +39,6 @@ import org.springframework.security.web.authentication.preauth.PreAuthenticatedA
import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import java.security.cert.Certificate
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
/** /**
* Tests for [X509Dsl] * Tests for [X509Dsl]
@ -140,9 +140,7 @@ class X509DslTests {
} }
@Bean @Bean
override fun userDetailsService(): UserDetailsService { override fun userDetailsService(): UserDetailsService = mockk()
return mock(UserDetailsService::class.java)
}
} }
@Test @Test
@ -174,9 +172,7 @@ class X509DslTests {
} }
@Bean @Bean
override fun userDetailsService(): UserDetailsService { override fun userDetailsService(): UserDetailsService = mockk()
return mock(UserDetailsService::class.java)
}
} }
@Test @Test

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,11 +16,12 @@
package org.springframework.security.config.web.servlet.oauth2.client package org.springframework.security.config.web.servlet.oauth2.client
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito.verify
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
@ -35,6 +36,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver
import org.springframework.security.oauth2.core.OAuth2AccessToken import org.springframework.security.oauth2.core.OAuth2AccessToken
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
@ -59,19 +61,29 @@ class AuthorizationCodeGrantDslTests {
@Test @Test
fun `oauth2Client when custom authorization request repository then repository used`() { fun `oauth2Client when custom authorization request repository then repository used`() {
this.spring.register(RequestRepositoryConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(RequestRepositoryConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(RequestRepositoryConfig.REQUEST_REPOSITORY)
val authorizationRequest = getOAuth2AuthorizationRequest()
every {
RequestRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())
} returns authorizationRequest
every {
RequestRepositoryConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())
} returns authorizationRequest
this.mockMvc.get("/callback") { this.mockMvc.get("/callback") {
param("state", "test") param("state", "test")
param("code", "123") param("code", "123")
} }
verify(RequestRepositoryConfig.REQUEST_REPOSITORY).loadAuthorizationRequest(any()) verify(exactly = 1) { RequestRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class RequestRepositoryConfig : WebSecurityConfigurerAdapter() { open class RequestRepositoryConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> val REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
HttpSessionOAuth2AuthorizationRequestRepository()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -91,30 +103,39 @@ class AuthorizationCodeGrantDslTests {
@Test @Test
fun `oauth2Client when custom access token response client then client used`() { fun `oauth2Client when custom access token response client then client used`() {
this.spring.register(AuthorizedClientConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthorizedClientConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthorizedClientConfig.REQUEST_REPOSITORY)
mockkObject(AuthorizedClientConfig.CLIENT)
val authorizationRequest = getOAuth2AuthorizationRequest() val authorizationRequest = getOAuth2AuthorizationRequest()
Mockito.`when`(AuthorizedClientConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())) every {
.thenReturn(authorizationRequest) AuthorizedClientConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())
Mockito.`when`(AuthorizedClientConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())) } returns authorizationRequest
.thenReturn(authorizationRequest) every {
Mockito.`when`(AuthorizedClientConfig.CLIENT.getTokenResponse(any())) AuthorizedClientConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())
.thenReturn(OAuth2AccessTokenResponse } returns authorizationRequest
.withToken("token") every {
.tokenType(OAuth2AccessToken.TokenType.BEARER) AuthorizedClientConfig.CLIENT.getTokenResponse(any())
.build()) } returns OAuth2AccessTokenResponse
.withToken("token")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.build()
this.mockMvc.get("/callback") { this.mockMvc.get("/callback") {
param("state", "test") param("state", "test")
param("code", "123") param("code", "123")
} }
verify(AuthorizedClientConfig.CLIENT).getTokenResponse(any()) verify(exactly = 1) { AuthorizedClientConfig.CLIENT.getTokenResponse(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class AuthorizedClientConfig : WebSecurityConfigurerAdapter() { open class AuthorizedClientConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> val REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = Mockito.mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> HttpSessionOAuth2AuthorizationRequestRepository()
val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
OAuth2AccessTokenResponseClient {
OAuth2AccessTokenResponse.withToken("some tokenValue").build()
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -135,26 +156,30 @@ class AuthorizationCodeGrantDslTests {
@Test @Test
fun `oauth2Client when custom authorization request resolver then request resolver used`() { fun `oauth2Client when custom authorization request resolver then request resolver used`() {
this.spring.register(RequestResolverConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(RequestResolverConfig::class.java, ClientConfig::class.java).autowire()
val requestResolverConfig = this.spring.context.getBean(RequestResolverConfig::class.java)
val authorizationRequest = getOAuth2AuthorizationRequest()
every {
requestResolverConfig.requestResolver.resolve(any())
} returns authorizationRequest
this.mockMvc.get("/callback") { this.mockMvc.get("/callback") {
param("state", "test") param("state", "test")
param("code", "123") param("code", "123")
} }
verify(RequestResolverConfig.REQUEST_RESOLVER).resolve(any()) verify(exactly = 1) { requestResolverConfig.requestResolver.resolve(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class RequestResolverConfig : WebSecurityConfigurerAdapter() { open class RequestResolverConfig : WebSecurityConfigurerAdapter() {
companion object {
var REQUEST_RESOLVER: OAuth2AuthorizationRequestResolver = Mockito.mock(OAuth2AuthorizationRequestResolver::class.java) val requestResolver: OAuth2AuthorizationRequestResolver = mockk()
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {
oauth2Client { oauth2Client {
authorizationCodeGrant { authorizationCodeGrant {
authorizationRequestResolver = REQUEST_RESOLVER authorizationRequestResolver = requestResolver
} }
} }
authorizeRequests { authorizeRequests {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,23 +16,25 @@
package org.springframework.security.config.web.servlet.oauth2.login package org.springframework.security.config.web.servlet.oauth2.login
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import javax.servlet.http.HttpServletRequest
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito.verify
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.oauth2.client.CommonOAuth2Provider import org.springframework.security.config.oauth2.client.CommonOAuth2Provider
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
@ -54,16 +56,27 @@ class AuthorizationEndpointDslTests {
@Test @Test
fun `oauth2Login when custom client registration repository then repository used`() { fun `oauth2Login when custom client registration repository then repository used`() {
this.spring.register(ResolverConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(ResolverConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(ResolverConfig.RESOLVER)
every { ResolverConfig.RESOLVER.resolve(any()) }
this.mockMvc.get("/oauth2/authorization/google") this.mockMvc.get("/oauth2/authorization/google")
verify(ResolverConfig.RESOLVER).resolve(any()) verify(exactly = 1) { ResolverConfig.RESOLVER.resolve(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class ResolverConfig : WebSecurityConfigurerAdapter() { open class ResolverConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var RESOLVER: OAuth2AuthorizationRequestResolver = Mockito.mock(OAuth2AuthorizationRequestResolver::class.java) val RESOLVER: OAuth2AuthorizationRequestResolver = object : OAuth2AuthorizationRequestResolver {
override fun resolve(
request: HttpServletRequest?
) = OAuth2AuthorizationRequest.authorizationCode().build()
override fun resolve(
request: HttpServletRequest?, clientRegistrationId: String?
) = OAuth2AuthorizationRequest.authorizationCode().build()
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -80,16 +93,20 @@ class AuthorizationEndpointDslTests {
@Test @Test
fun `oauth2Login when custom authorization request repository then repository used`() { fun `oauth2Login when custom authorization request repository then repository used`() {
this.spring.register(RequestRepoConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(RequestRepoConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(RequestRepoConfig.REPOSITORY)
every { RequestRepoConfig.REPOSITORY.saveAuthorizationRequest(any(), any(), any()) }
this.mockMvc.get("/oauth2/authorization/google") this.mockMvc.get("/oauth2/authorization/google")
verify(RequestRepoConfig.REPOSITORY).saveAuthorizationRequest(any(), any(), any()) verify(exactly = 1) { RequestRepoConfig.REPOSITORY.saveAuthorizationRequest(any(), any(), any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class RequestRepoConfig : WebSecurityConfigurerAdapter() { open class RequestRepoConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
HttpSessionOAuth2AuthorizationRequestRepository()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -106,16 +123,19 @@ class AuthorizationEndpointDslTests {
@Test @Test
fun `oauth2Login when custom authorization uri repository then uri used`() { fun `oauth2Login when custom authorization uri repository then uri used`() {
this.spring.register(AuthorizationUriConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(AuthorizationUriConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(AuthorizationUriConfig.REPOSITORY)
this.mockMvc.get("/connect/google") this.mockMvc.get("/connect/google")
verify(AuthorizationUriConfig.REPOSITORY).saveAuthorizationRequest(any(), any(), any()) verify(exactly = 1) { AuthorizationUriConfig.REPOSITORY.saveAuthorizationRequest(any(), any(), any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class AuthorizationUriConfig : WebSecurityConfigurerAdapter() { open class AuthorizationUriConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
HttpSessionOAuth2AuthorizationRequestRepository()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,11 +16,10 @@
package org.springframework.security.config.web.servlet.oauth2.login package org.springframework.security.config.web.servlet.oauth2.login
import io.mockk.every
import io.mockk.mockkObject
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers
import org.mockito.Mockito
import org.mockito.Mockito.mock
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
@ -29,15 +28,17 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.oauth2.client.CommonOAuth2Provider import org.springframework.security.config.oauth2.client.CommonOAuth2Provider
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.config.web.servlet.invoke import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService import org.springframework.security.oauth2.client.userinfo.OAuth2UserService
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
import org.springframework.security.oauth2.core.OAuth2AccessToken import org.springframework.security.oauth2.core.OAuth2AccessToken
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
@ -46,7 +47,6 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User
import org.springframework.security.oauth2.core.user.OAuth2User import org.springframework.security.oauth2.core.user.OAuth2User
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import java.util.*
/** /**
* Tests for [RedirectionEndpointDsl] * Tests for [RedirectionEndpointDsl]
@ -64,6 +64,9 @@ class RedirectionEndpointDslTests {
@Test @Test
fun `oauth2Login when redirection endpoint configured then custom redirection endpoing used`() { fun `oauth2Login when redirection endpoint configured then custom redirection endpoing used`() {
this.spring.register(UserServiceConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(UserServiceConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(UserServiceConfig.REPOSITORY)
mockkObject(UserServiceConfig.CLIENT)
mockkObject(UserServiceConfig.USER_SERVICE)
val registrationId = "registrationId" val registrationId = "registrationId"
val attributes = HashMap<String, Any>() val attributes = HashMap<String, Any>()
@ -76,15 +79,18 @@ class RedirectionEndpointDslTests {
.redirectUri("http://localhost/callback") .redirectUri("http://localhost/callback")
.attributes(attributes) .attributes(attributes)
.build() .build()
Mockito.`when`(UserServiceConfig.REPOSITORY.removeAuthorizationRequest(ArgumentMatchers.any(), ArgumentMatchers.any())) every {
.thenReturn(authorizationRequest) UserServiceConfig.REPOSITORY.removeAuthorizationRequest(any(), any())
Mockito.`when`(UserServiceConfig.CLIENT.getTokenResponse(ArgumentMatchers.any())) } returns authorizationRequest
.thenReturn(OAuth2AccessTokenResponse every {
.withToken("token") UserServiceConfig.CLIENT.getTokenResponse(any())
.tokenType(OAuth2AccessToken.TokenType.BEARER) } returns OAuth2AccessTokenResponse
.build()) .withToken("token")
Mockito.`when`(UserServiceConfig.USER_SERVICE.loadUser(ArgumentMatchers.any())) .tokenType(OAuth2AccessToken.TokenType.BEARER)
.thenReturn(DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user")) .build()
every {
UserServiceConfig.USER_SERVICE.loadUser(any())
} returns DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user")
this.mockMvc.get("/callback") { this.mockMvc.get("/callback") {
param("code", "auth-code") param("code", "auth-code")
@ -96,10 +102,15 @@ class RedirectionEndpointDslTests {
@EnableWebSecurity @EnableWebSecurity
open class UserServiceConfig : WebSecurityConfigurerAdapter() { open class UserServiceConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = mock(OAuth2UserService::class.java) as OAuth2UserService<OAuth2UserRequest, OAuth2User> val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> HttpSessionOAuth2AuthorizationRequestRepository()
var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
OAuth2AccessTokenResponseClient {
OAuth2AccessTokenResponse.withToken("some tokenValue").build()
}
val USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = DefaultOAuth2UserService()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,11 @@
package org.springframework.security.config.web.servlet.oauth2.login package org.springframework.security.config.web.servlet.oauth2.login
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito.`when`
import org.mockito.Mockito.mock
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
@ -36,13 +35,13 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
import org.springframework.security.oauth2.core.OAuth2AccessToken import org.springframework.security.oauth2.core.OAuth2AccessToken
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import java.util.*
/** /**
* Tests for [TokenEndpointDsl] * Tests for [TokenEndpointDsl]
@ -60,6 +59,8 @@ class TokenEndpointDslTests {
@Test @Test
fun `oauth2Login when custom access token response client then client used`() { fun `oauth2Login when custom access token response client then client used`() {
this.spring.register(TokenConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(TokenConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(TokenConfig.REPOSITORY)
mockkObject(TokenConfig.CLIENT)
val registrationId = "registrationId" val registrationId = "registrationId"
val attributes = HashMap<String, Any>() val attributes = HashMap<String, Any>()
@ -72,26 +73,34 @@ class TokenEndpointDslTests {
.redirectUri("http://localhost/login/oauth2/code/google") .redirectUri("http://localhost/login/oauth2/code/google")
.attributes(attributes) .attributes(attributes)
.build() .build()
`when`(TokenConfig.REPOSITORY.removeAuthorizationRequest(any(), any())) every {
.thenReturn(authorizationRequest) TokenConfig.REPOSITORY.removeAuthorizationRequest(any(), any())
`when`(TokenConfig.CLIENT.getTokenResponse(any())).thenReturn(OAuth2AccessTokenResponse } returns authorizationRequest
.withToken("token") every {
.tokenType(OAuth2AccessToken.TokenType.BEARER) TokenConfig.CLIENT.getTokenResponse(any())
.build()) } returns OAuth2AccessTokenResponse
.withToken("token")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.build()
this.mockMvc.get("/login/oauth2/code/google") { this.mockMvc.get("/login/oauth2/code/google") {
param("code", "auth-code") param("code", "auth-code")
param("state", "test") param("state", "test")
} }
Mockito.verify(TokenConfig.CLIENT).getTokenResponse(any()) verify(exactly = 1) { TokenConfig.CLIENT.getTokenResponse(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class TokenConfig : WebSecurityConfigurerAdapter() { open class TokenConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> HttpSessionOAuth2AuthorizationRequestRepository()
val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
OAuth2AccessTokenResponseClient {
OAuth2AccessTokenResponse.withToken("some tokenValue").build()
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,21 +16,22 @@
package org.springframework.security.config.web.servlet.oauth2.login package org.springframework.security.config.web.servlet.oauth2.login
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito.`when`
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.oauth2.client.CommonOAuth2Provider import org.springframework.security.config.oauth2.client.CommonOAuth2Provider
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
@ -46,7 +47,6 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User
import org.springframework.security.oauth2.core.user.OAuth2User import org.springframework.security.oauth2.core.user.OAuth2User
import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import java.util.*
/** /**
* Tests for [UserInfoEndpointDsl] * Tests for [UserInfoEndpointDsl]
@ -64,6 +64,9 @@ class UserInfoEndpointDslTests {
@Test @Test
fun `oauth2Login when custom user service then user service used`() { fun `oauth2Login when custom user service then user service used`() {
this.spring.register(UserServiceConfig::class.java, ClientConfig::class.java).autowire() this.spring.register(UserServiceConfig::class.java, ClientConfig::class.java).autowire()
mockkObject(UserServiceConfig.REPOSITORY)
mockkObject(UserServiceConfig.CLIENT)
mockkObject(UserServiceConfig.USER_SERVICE)
val registrationId = "registrationId" val registrationId = "registrationId"
val attributes = HashMap<String, Any>() val attributes = HashMap<String, Any>()
@ -76,31 +79,35 @@ class UserInfoEndpointDslTests {
.redirectUri("http://localhost/login/oauth2/code/google") .redirectUri("http://localhost/login/oauth2/code/google")
.attributes(attributes) .attributes(attributes)
.build() .build()
`when`(UserServiceConfig.REPOSITORY.removeAuthorizationRequest(any(), any())) every {
.thenReturn(authorizationRequest) UserServiceConfig.REPOSITORY.removeAuthorizationRequest(any(), any())
`when`(UserServiceConfig.CLIENT.getTokenResponse(any())) } returns authorizationRequest
.thenReturn(OAuth2AccessTokenResponse every {
.withToken("token") UserServiceConfig.CLIENT.getTokenResponse(any())
.tokenType(OAuth2AccessToken.TokenType.BEARER) } returns OAuth2AccessTokenResponse
.build()) .withToken("token")
`when`(UserServiceConfig.USER_SERVICE.loadUser(any())) .tokenType(OAuth2AccessToken.TokenType.BEARER)
.thenReturn(DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user")) .build()
every {
UserServiceConfig.USER_SERVICE.loadUser(any())
} returns DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user")
this.mockMvc.get("/login/oauth2/code/google") { this.mockMvc.get("/login/oauth2/code/google") {
param("code", "auth-code") param("code", "auth-code")
param("state", "test") param("state", "test")
} }
Mockito.verify(UserServiceConfig.USER_SERVICE).loadUser(any()) verify(exactly = 1) { UserServiceConfig.USER_SERVICE.loadUser(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class UserServiceConfig : WebSecurityConfigurerAdapter() { open class UserServiceConfig : WebSecurityConfigurerAdapter() {
companion object {
var USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = Mockito.mock(OAuth2UserService::class.java) as OAuth2UserService<OAuth2UserRequest, OAuth2User> companion object {
var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = Mockito.mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mockk()
var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest> val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mockk()
} val USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = mockk()
}
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
http { http {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,9 +16,12 @@
package org.springframework.security.config.web.servlet.oauth2.resourceserver package org.springframework.security.config.web.servlet.oauth2.resourceserver
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.core.convert.converter.Converter import org.springframework.core.convert.converter.Converter
@ -59,7 +62,7 @@ class JwtDslTests {
http { http {
oauth2ResourceServer { oauth2ResourceServer {
jwt { jwt {
jwtDecoder = mock(JwtDecoder::class.java) jwtDecoder = mockk()
} }
} }
} }
@ -87,25 +90,32 @@ class JwtDslTests {
@Test @Test
fun `JWT when custom JWT authentication converter then converter used`() { fun `JWT when custom JWT authentication converter then converter used`() {
this.spring.register(CustomJwtAuthenticationConverterConfig::class.java).autowire() this.spring.register(CustomJwtAuthenticationConverterConfig::class.java).autowire()
`when`(CustomJwtAuthenticationConverterConfig.DECODER.decode(anyString())).thenReturn( mockkObject(CustomJwtAuthenticationConverterConfig.CONVERTER)
Jwt.withTokenValue("token") mockkObject(CustomJwtAuthenticationConverterConfig.DECODER)
.header("alg", "none") every {
.claim(IdTokenClaimNames.SUB, "user") CustomJwtAuthenticationConverterConfig.DECODER.decode(any())
.build()) } returns Jwt.withTokenValue("token")
`when`(CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any())) .header("alg", "none")
.thenReturn(TestingAuthenticationToken("test", "this", "ROLE")) .claim(IdTokenClaimNames.SUB, "user")
.build()
every {
CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any())
} returns TestingAuthenticationToken("test", "this", "ROLE")
this.mockMvc.get("/") { this.mockMvc.get("/") {
header("Authorization", "Bearer token") header("Authorization", "Bearer token")
} }
verify(CustomJwtAuthenticationConverterConfig.CONVERTER).convert(any()) verify(exactly = 1) { CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class CustomJwtAuthenticationConverterConfig : WebSecurityConfigurerAdapter() { open class CustomJwtAuthenticationConverterConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var CONVERTER: Converter<Jwt, out AbstractAuthenticationToken> = mock(Converter::class.java) as Converter<Jwt, out AbstractAuthenticationToken> val CONVERTER: Converter<Jwt, out AbstractAuthenticationToken> = Converter { _ ->
var DECODER: JwtDecoder = mock(JwtDecoder::class.java) TestingAuthenticationToken("a", "b", "c")
}
val DECODER: JwtDecoder = JwtDecoder { Jwt.withTokenValue("some tokenValue").build() }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -122,31 +132,32 @@ class JwtDslTests {
} }
@Bean @Bean
open fun jwtDecoder(): JwtDecoder { open fun jwtDecoder(): JwtDecoder = DECODER
return DECODER
}
} }
@Test @Test
fun `JWT when custom JWT decoder set after jwkSetUri then decoder used`() { fun `JWT when custom JWT decoder set after jwkSetUri then decoder used`() {
this.spring.register(JwtDecoderAfterJwkSetUriConfig::class.java).autowire() this.spring.register(JwtDecoderAfterJwkSetUriConfig::class.java).autowire()
`when`(JwtDecoderAfterJwkSetUriConfig.DECODER.decode(anyString())).thenReturn( mockkObject(JwtDecoderAfterJwkSetUriConfig.DECODER)
Jwt.withTokenValue("token") every {
.header("alg", "none") JwtDecoderAfterJwkSetUriConfig.DECODER.decode(any())
.claim(IdTokenClaimNames.SUB, "user") } returns Jwt.withTokenValue("token")
.build()) .header("alg", "none")
.claim(IdTokenClaimNames.SUB, "user")
.build()
this.mockMvc.get("/") { this.mockMvc.get("/") {
header("Authorization", "Bearer token") header("Authorization", "Bearer token")
} }
verify(JwtDecoderAfterJwkSetUriConfig.DECODER).decode(any()) verify(exactly = 1) { JwtDecoderAfterJwkSetUriConfig.DECODER.decode(any()) }
} }
@EnableWebSecurity @EnableWebSecurity
open class JwtDecoderAfterJwkSetUriConfig : WebSecurityConfigurerAdapter() { open class JwtDecoderAfterJwkSetUriConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var DECODER: JwtDecoder = mock(JwtDecoder::class.java) val DECODER: JwtDecoder = JwtDecoder { Jwt.withTokenValue("some tokenValue").build() }
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,22 +16,23 @@
package org.springframework.security.config.web.servlet.oauth2.resourceserver package org.springframework.security.config.web.servlet.oauth2.resourceserver
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers
import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.eq
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.http.* import org.springframework.http.HttpHeaders
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.http.ResponseEntity
import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.core.Authentication import org.springframework.security.core.Authentication
import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal
import org.springframework.security.oauth2.jwt.JwtClaimNames import org.springframework.security.oauth2.jwt.JwtClaimNames
import org.springframework.security.oauth2.server.resource.introspection.NimbusOpaqueTokenIntrospector import org.springframework.security.oauth2.server.resource.introspection.NimbusOpaqueTokenIntrospector
@ -41,6 +42,7 @@ import org.springframework.test.web.servlet.get
import org.springframework.web.bind.annotation.GetMapping import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import org.springframework.web.client.RestOperations import org.springframework.web.client.RestOperations
import org.springframework.web.client.RestTemplate
/** /**
* Tests for [OpaqueTokenDsl] * Tests for [OpaqueTokenDsl]
@ -58,16 +60,19 @@ class OpaqueTokenDslTests {
@Test @Test
fun `opaque token when defaults then uses introspection`() { fun `opaque token when defaults then uses introspection`() {
this.spring.register(DefaultOpaqueConfig::class.java, AuthenticationController::class.java).autowire() this.spring.register(DefaultOpaqueConfig::class.java, AuthenticationController::class.java).autowire()
val headers = HttpHeaders() mockkObject(DefaultOpaqueConfig.REST)
headers.contentType = MediaType.APPLICATION_JSON val headers = HttpHeaders().apply {
contentType = MediaType.APPLICATION_JSON
}
val entity = ResponseEntity("{\n" + val entity = ResponseEntity("{\n" +
" \"active\" : true,\n" + " \"active\" : true,\n" +
" \"sub\": \"test-subject\",\n" + " \"sub\": \"test-subject\",\n" +
" \"scope\": \"message:read\",\n" + " \"scope\": \"message:read\",\n" +
" \"exp\": 4683883211\n" + " \"exp\": 4683883211\n" +
"}", headers, HttpStatus.OK) "}", headers, HttpStatus.OK)
`when`(DefaultOpaqueConfig.REST.exchange(any(RequestEntity::class.java), eq(String::class.java))) every {
.thenReturn(entity) DefaultOpaqueConfig.REST.exchange(any(), eq(String::class.java))
} returns entity
this.mockMvc.get("/authenticated") { this.mockMvc.get("/authenticated") {
header("Authorization", "Bearer token") header("Authorization", "Bearer token")
@ -79,8 +84,9 @@ class OpaqueTokenDslTests {
@EnableWebSecurity @EnableWebSecurity
open class DefaultOpaqueConfig : WebSecurityConfigurerAdapter() { open class DefaultOpaqueConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var REST: RestOperations = mock(RestOperations::class.java) val REST: RestOperations = RestTemplate()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -95,9 +101,7 @@ class OpaqueTokenDslTests {
} }
@Bean @Bean
open fun rest(): RestOperations { open fun rest(): RestOperations = REST
return REST
}
@Bean @Bean
open fun tokenIntrospectionClient(): NimbusOpaqueTokenIntrospector { open fun tokenIntrospectionClient(): NimbusOpaqueTokenIntrospector {
@ -108,20 +112,26 @@ class OpaqueTokenDslTests {
@Test @Test
fun `opaque token when custom introspector set then introspector used`() { fun `opaque token when custom introspector set then introspector used`() {
this.spring.register(CustomIntrospectorConfig::class.java, AuthenticationController::class.java).autowire() this.spring.register(CustomIntrospectorConfig::class.java, AuthenticationController::class.java).autowire()
`when`(CustomIntrospectorConfig.INTROSPECTOR.introspect(ArgumentMatchers.anyString())) mockkObject(CustomIntrospectorConfig.INTROSPECTOR)
.thenReturn(DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList()))
every {
CustomIntrospectorConfig.INTROSPECTOR.introspect(any())
} returns DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList())
this.mockMvc.get("/authenticated") { this.mockMvc.get("/authenticated") {
header("Authorization", "Bearer token") header("Authorization", "Bearer token")
} }
verify(CustomIntrospectorConfig.INTROSPECTOR).introspect("token") verify(exactly = 1) { CustomIntrospectorConfig.INTROSPECTOR.introspect("token") }
} }
@EnableWebSecurity @EnableWebSecurity
open class CustomIntrospectorConfig : WebSecurityConfigurerAdapter() { open class CustomIntrospectorConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var INTROSPECTOR: OpaqueTokenIntrospector = mock(OpaqueTokenIntrospector::class.java) val INTROSPECTOR: OpaqueTokenIntrospector = OpaqueTokenIntrospector {
DefaultOAuth2AuthenticatedPrincipal(emptyMap(), emptyList())
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -141,20 +151,25 @@ class OpaqueTokenDslTests {
@Test @Test
fun `opaque token when custom introspector set after client credentials then introspector used`() { fun `opaque token when custom introspector set after client credentials then introspector used`() {
this.spring.register(IntrospectorAfterClientCredentialsConfig::class.java, AuthenticationController::class.java).autowire() this.spring.register(IntrospectorAfterClientCredentialsConfig::class.java, AuthenticationController::class.java).autowire()
`when`(IntrospectorAfterClientCredentialsConfig.INTROSPECTOR.introspect(ArgumentMatchers.anyString())) mockkObject(IntrospectorAfterClientCredentialsConfig.INTROSPECTOR)
.thenReturn(DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList())) every {
IntrospectorAfterClientCredentialsConfig.INTROSPECTOR.introspect(any())
} returns DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList())
this.mockMvc.get("/authenticated") { this.mockMvc.get("/authenticated") {
header("Authorization", "Bearer token") header("Authorization", "Bearer token")
} }
verify(IntrospectorAfterClientCredentialsConfig.INTROSPECTOR).introspect("token") verify(exactly = 1) { IntrospectorAfterClientCredentialsConfig.INTROSPECTOR.introspect("token") }
} }
@EnableWebSecurity @EnableWebSecurity
open class IntrospectorAfterClientCredentialsConfig : WebSecurityConfigurerAdapter() { open class IntrospectorAfterClientCredentialsConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
var INTROSPECTOR: OpaqueTokenIntrospector = mock(OpaqueTokenIntrospector::class.java) val INTROSPECTOR: OpaqueTokenIntrospector = OpaqueTokenIntrospector {
DefaultOAuth2AuthenticatedPrincipal(emptyMap(), emptyList())
}
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,9 +16,11 @@
package org.springframework.security.config.web.servlet.session package org.springframework.security.config.web.servlet.session
import io.mockk.every
import io.mockk.mockkObject
import java.util.Date
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.mockito.Mockito.*
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.Configuration
@ -27,11 +29,12 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
import org.springframework.security.config.test.SpringTestRule import org.springframework.security.config.test.SpringTestRule
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.core.session.SessionInformation import org.springframework.security.core.session.SessionInformation
import org.springframework.security.core.session.SessionRegistry import org.springframework.security.core.session.SessionRegistry
import org.springframework.security.core.session.SessionRegistryImpl
import org.springframework.security.core.userdetails.User import org.springframework.security.core.userdetails.User
import org.springframework.security.core.userdetails.UserDetailsService import org.springframework.security.core.userdetails.UserDetailsService
import org.springframework.security.config.web.servlet.invoke
import org.springframework.security.provisioning.InMemoryUserDetailsManager import org.springframework.security.provisioning.InMemoryUserDetailsManager
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf
import org.springframework.security.web.session.SimpleRedirectSessionInformationExpiredStrategy import org.springframework.security.web.session.SimpleRedirectSessionInformationExpiredStrategy
@ -40,7 +43,6 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl import org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
import java.util.*
/** /**
* Tests for [SessionConcurrencyDsl] * Tests for [SessionConcurrencyDsl]
@ -90,11 +92,12 @@ class SessionConcurrencyDslTests {
@Test @Test
fun `session concurrency when expired url then redirects to url`() { fun `session concurrency when expired url then redirects to url`() {
this.spring.register(ExpiredUrlConfig::class.java).autowire() this.spring.register(ExpiredUrlConfig::class.java).autowire()
mockkObject(ExpiredUrlConfig.SESSION_REGISTRY)
val session = MockHttpSession() val session = MockHttpSession()
val sessionInformation = SessionInformation("", session.id, Date(0)) val sessionInformation = SessionInformation("", session.id, Date(0))
sessionInformation.expireNow() sessionInformation.expireNow()
`when`(ExpiredUrlConfig.sessionRegistry.getSessionInformation(any())).thenReturn(sessionInformation) every { ExpiredUrlConfig.SESSION_REGISTRY.getSessionInformation(any()) } returns sessionInformation
this.mockMvc.perform(get("/").session(session)) this.mockMvc.perform(get("/").session(session))
.andExpect(redirectedUrl("/expired-session")) .andExpect(redirectedUrl("/expired-session"))
@ -102,8 +105,9 @@ class SessionConcurrencyDslTests {
@EnableWebSecurity @EnableWebSecurity
open class ExpiredUrlConfig : WebSecurityConfigurerAdapter() { open class ExpiredUrlConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
val sessionRegistry: SessionRegistry = mock(SessionRegistry::class.java) val SESSION_REGISTRY: SessionRegistry = SessionRegistryImpl()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -112,26 +116,25 @@ class SessionConcurrencyDslTests {
sessionConcurrency { sessionConcurrency {
maximumSessions = 1 maximumSessions = 1
expiredUrl = "/expired-session" expiredUrl = "/expired-session"
sessionRegistry = sessionRegistry() sessionRegistry = SESSION_REGISTRY
} }
} }
} }
} }
@Bean @Bean
open fun sessionRegistry(): SessionRegistry { open fun sessionRegistry(): SessionRegistry = SESSION_REGISTRY
return sessionRegistry
}
} }
@Test @Test
fun `session concurrency when expired session strategy then strategy used`() { fun `session concurrency when expired session strategy then strategy used`() {
this.spring.register(ExpiredSessionStrategyConfig::class.java).autowire() this.spring.register(ExpiredSessionStrategyConfig::class.java).autowire()
mockkObject(ExpiredSessionStrategyConfig.SESSION_REGISTRY)
val session = MockHttpSession() val session = MockHttpSession()
val sessionInformation = SessionInformation("", session.id, Date(0)) val sessionInformation = SessionInformation("", session.id, Date(0))
sessionInformation.expireNow() sessionInformation.expireNow()
`when`(ExpiredSessionStrategyConfig.sessionRegistry.getSessionInformation(any())).thenReturn(sessionInformation) every { ExpiredSessionStrategyConfig.SESSION_REGISTRY.getSessionInformation(any()) } returns sessionInformation
this.mockMvc.perform(get("/").session(session)) this.mockMvc.perform(get("/").session(session))
.andExpect(redirectedUrl("/expired-session")) .andExpect(redirectedUrl("/expired-session"))
@ -139,8 +142,9 @@ class SessionConcurrencyDslTests {
@EnableWebSecurity @EnableWebSecurity
open class ExpiredSessionStrategyConfig : WebSecurityConfigurerAdapter() { open class ExpiredSessionStrategyConfig : WebSecurityConfigurerAdapter() {
companion object { companion object {
val sessionRegistry: SessionRegistry = mock(SessionRegistry::class.java) val SESSION_REGISTRY: SessionRegistry = SessionRegistryImpl()
} }
override fun configure(http: HttpSecurity) { override fun configure(http: HttpSecurity) {
@ -149,16 +153,14 @@ class SessionConcurrencyDslTests {
sessionConcurrency { sessionConcurrency {
maximumSessions = 1 maximumSessions = 1
expiredSessionStrategy = SimpleRedirectSessionInformationExpiredStrategy("/expired-session") expiredSessionStrategy = SimpleRedirectSessionInformationExpiredStrategy("/expired-session")
sessionRegistry = sessionRegistry() sessionRegistry = SESSION_REGISTRY
} }
} }
} }
} }
@Bean @Bean
open fun sessionRegistry(): SessionRegistry { open fun sessionRegistry(): SessionRegistry = SESSION_REGISTRY
return sessionRegistry
}
} }
@Configuration @Configuration