From 63f48167bd96d37ed96813d6ab4246abaef5ba96 Mon Sep 17 00:00:00 2001 From: Blagoja Stamatovski <95382055+call-me-baki@users.noreply.github.com> Date: Sat, 18 May 2024 15:39:27 +0200 Subject: [PATCH] Add Kotlin support to PreFilter and PostFilter annotations Closes gh-15093 --- core/spring-security-core.gradle | 14 ++ ...efaultMethodSecurityExpressionHandler.java | 44 +++- ...hodSecurityExpressionHandlerKotlinTests.kt | 236 ++++++++++++++++++ .../authorization/method-security.adoc | 114 ++++++++- 4 files changed, 387 insertions(+), 21 deletions(-) create mode 100644 core/src/test/kotlin/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerKotlinTests.kt diff --git a/core/spring-security-core.gradle b/core/spring-security-core.gradle index fd5857a20c..ee2a79ee99 100644 --- a/core/spring-security-core.gradle +++ b/core/spring-security-core.gradle @@ -1,6 +1,8 @@ +import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import java.util.concurrent.Callable apply plugin: 'io.spring.convention.spring-module' +apply plugin: 'kotlin' dependencies { management platform(project(":spring-security-dependencies")) @@ -31,6 +33,9 @@ dependencies { testImplementation "org.springframework:spring-test" testImplementation 'org.skyscreamer:jsonassert' testImplementation 'org.springframework:spring-test' + testImplementation 'org.jetbrains.kotlin:kotlin-reflect' + testImplementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8' + testImplementation 'io.mockk:mockk' testRuntimeOnly 'org.hsqldb:hsqldb' } @@ -57,3 +62,12 @@ Callable springVersion() { return (Callable) { project.configurations.compileClasspath.resolvedConfiguration.resolvedArtifacts .find { it.name == 'spring-core' }.moduleVersion.id.version } } + +tasks.withType(KotlinCompile).configureEach { + kotlinOptions { + languageVersion = "1.7" + apiVersion = "1.7" + freeCompilerArgs = ["-Xjsr305=strict", "-Xsuppress-version-warnings"] + jvmTarget = "17" + } +} diff --git a/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java b/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java index 76cf8cf2dc..8d9175bf5b 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,7 @@ import org.springframework.util.Assert; * * @author Luke Taylor * @author Evgeniy Cheban + * @author Blagoja Stamatovski * @since 3.0 */ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpressionHandler @@ -109,12 +110,13 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr } /** - * Filters the {@code filterTarget} object (which must be either a collection, array, - * map or stream), by evaluating the supplied expression. + * Filters the {@code filterTarget} object (which must be either a {@link Collection}, + * {@code Array}, {@link Map} or {@link Stream}), by evaluating the supplied + * expression. *

- * If a {@code Collection} or {@code Map} is used, the original instance will be - * modified to contain the elements for which the permission expression evaluates to - * {@code true}. For an array, a new array instance will be returned. + * Returns new instances of the same type as the supplied {@code filterTarget} object + * @return The filtered {@link Collection}, {@code Array}, {@link Map} or + * {@link Stream} */ @Override public Object filter(Object filterTarget, Expression filterExpression, EvaluationContext ctx) { @@ -151,9 +153,17 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr } } this.logger.debug(LogMessage.format("Retaining elements: %s", retain)); - filterTarget.clear(); - filterTarget.addAll(retain); - return filterTarget; + try { + filterTarget.clear(); + filterTarget.addAll(retain); + return filterTarget; + } + catch (UnsupportedOperationException unsupportedOperationException) { + this.logger.debug(LogMessage.format( + "Collection threw exception: %s. Will return a new instance instead of mutating its state.", + unsupportedOperationException.getMessage())); + return retain; + } } private Object filterArray(Object[] filterTarget, Expression filterExpression, EvaluationContext ctx, @@ -178,7 +188,7 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr return filtered; } - private Object filterMap(final Map filterTarget, Expression filterExpression, EvaluationContext ctx, + private Object filterMap(Map filterTarget, Expression filterExpression, EvaluationContext ctx, MethodSecurityExpressionOperations rootObject) { Map retain = new LinkedHashMap<>(filterTarget.size()); this.logger.debug(LogMessage.format("Filtering map with %s elements", filterTarget.size())); @@ -189,9 +199,17 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr } } this.logger.debug(LogMessage.format("Retaining elements: %s", retain)); - filterTarget.clear(); - filterTarget.putAll(retain); - return filterTarget; + try { + filterTarget.clear(); + filterTarget.putAll(retain); + return filterTarget; + } + catch (UnsupportedOperationException unsupportedOperationException) { + this.logger.debug(LogMessage.format( + "Map threw exception: %s. Will return a new instance instead of mutating its state.", + unsupportedOperationException.getMessage())); + return retain; + } } private Object filterStream(final Stream filterTarget, Expression filterExpression, EvaluationContext ctx, diff --git a/core/src/test/kotlin/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerKotlinTests.kt b/core/src/test/kotlin/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerKotlinTests.kt new file mode 100644 index 0000000000..af13894fe1 --- /dev/null +++ b/core/src/test/kotlin/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerKotlinTests.kt @@ -0,0 +1,236 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.access.expression.method + +import io.mockk.every +import io.mockk.mockk +import org.aopalliance.intercept.MethodInvocation +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.springframework.expression.EvaluationContext +import org.springframework.expression.Expression +import org.springframework.security.core.Authentication +import java.util.stream.Stream +import kotlin.reflect.jvm.internal.impl.load.kotlin.JvmType +import kotlin.reflect.jvm.javaMethod + +/** + * @author Blagoja Stamatovski + */ +class DefaultMethodSecurityExpressionHandlerKotlinTests { + private object Foo { + fun bar() { + } + } + + private lateinit var authentication: Authentication + private lateinit var methodInvocation: MethodInvocation + + private val handler: MethodSecurityExpressionHandler = DefaultMethodSecurityExpressionHandler() + + @BeforeEach + fun setUp() { + authentication = mockk() + methodInvocation = mockk() + + every { methodInvocation.`this` } returns { Foo } + every { methodInvocation.method } answers { Foo::bar.javaMethod!! } + every { methodInvocation.arguments } answers { arrayOf() } + } + + @Test + fun `filters non-empty maps`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject.key eq 'key2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val nonEmptyMap: Map = mapOf( + "key1" to "value1", + "key2" to "value2", + "key3" to "value3", + ) + + val filtered: Any = handler.filter( + /* filterTarget = */ nonEmptyMap, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Map::class.java) + val result = (filtered as Map) + assertThat(result).hasSize(1) + assertThat(result).containsKey("key2") + assertThat(result).containsValue("value2") + } + + @Test + fun `filters empty maps`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject.key eq 'key2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val emptyMap: Map = emptyMap() + + val filtered: Any = handler.filter( + /* filterTarget = */ emptyMap, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Map::class.java) + val result = (filtered as Map) + assertThat(result).hasSize(0) + } + + @Test + fun `filters non-empty collections`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val nonEmptyCollection: Collection = listOf( + "string1", + "string2", + "string1", + ) + + val filtered: Any = handler.filter( + /* filterTarget = */ nonEmptyCollection, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Collection::class.java) + val result = (filtered as Collection) + assertThat(result).hasSize(1) + assertThat(result).contains("string2") + } + + @Test + fun `filters empty collections`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val emptyCollection: Collection = emptyList() + + val filtered: Any = handler.filter( + /* filterTarget = */ emptyCollection, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Collection::class.java) + val result = (filtered as Collection) + assertThat(result).hasSize(0) + } + + @Test + fun `filters non-empty arrays`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val nonEmptyArray: Array = arrayOf( + "string1", + "string2", + "string1", + ) + + val filtered: Any = handler.filter( + /* filterTarget = */ nonEmptyArray, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Array::class.java) + val result = (filtered as Array) + assertThat(result).hasSize(1) + assertThat(result).contains("string2") + } + + @Test + fun `filters empty arrays`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val emptyArray: Array = emptyArray() + + val filtered: Any = handler.filter( + /* filterTarget = */ emptyArray, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Array::class.java) + val result = (filtered as Array) + assertThat(result).hasSize(0) + } + + @Test + fun `filters non-empty streams`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val nonEmptyStream: Stream = listOf( + "string1", + "string2", + "string1", + ).stream() + + val filtered: Any = handler.filter( + /* filterTarget = */ nonEmptyStream, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Stream::class.java) + val result = (filtered as Stream).toList() + assertThat(result).hasSize(1) + assertThat(result).contains("string2") + } + + @Test + fun `filters empty streams`() { + val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'") + val context: EvaluationContext = handler.createEvaluationContext( + /* authentication = */ authentication, + /* invocation = */ methodInvocation, + ) + val emptyStream: Stream = emptyList().stream() + + val filtered: Any = handler.filter( + /* filterTarget = */ emptyStream, + /* filterExpression = */ expression, + /* ctx = */ context, + ) + + assertThat(filtered).isInstanceOf(Stream::class.java) + val result = (filtered as Stream).toList() + assertThat(result).hasSize(0) + } +} diff --git a/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc b/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc index af6fad39da..2c81947cdd 100644 --- a/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc +++ b/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc @@ -546,9 +546,6 @@ If not, Spring Security will throw an `AccessDeniedException` and return a 403 s [[use-prefilter]] === Filtering Method Parameters with `@PreFilter` -[NOTE] -`@PreFilter` is not yet supported for Kotlin-specific data types; for that reason, only Java snippets are shown - When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PreFilter.html[`@PreFilter`] annotation like so: [tabs] @@ -566,6 +563,20 @@ public class BankService { } } ---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@Component +open class BankService { + @PreFilter("filterObject.owner == authentication.name") + fun updateAccounts(vararg accounts: Account): Collection { + // ... `accounts` will only contain the accounts owned by the logged-in user + return updated + } +} +---- ====== This is meant to filter out any values from `accounts` where the expression `filterObject.owner == authentication.name` fails. @@ -591,6 +602,23 @@ void updateAccountsWhenOwnedThenReturns() { assertThat(updated).containsOnly(ownedBy); } ---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@Autowired +lateinit var bankService: BankService + +@WithMockUser(username="owner") +@Test +fun updateAccountsWhenOwnedThenReturns() { + val ownedBy: Account = ... + val notOwnedBy: Account = ... + val updated: Collection = bankService.updateAccounts(ownedBy, notOwnedBy) + assertThat(updated).containsOnly(ownedBy) +} +---- ====== [TIP] @@ -618,6 +646,23 @@ public Collection updateAccounts(Map accounts) @PreFilter("filterObject.owner == authentication.name") public Collection updateAccounts(Stream accounts) ---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@PreFilter("filterObject.owner == authentication.name") +fun updateAccounts(accounts: Array): Collection + +@PreFilter("filterObject.owner == authentication.name") +fun updateAccounts(accounts: Collection): Collection + +@PreFilter("filterObject.value.owner == authentication.name") +fun updateAccounts(accounts: Map): Collection + +@PreFilter("filterObject.owner == authentication.name") +fun updateAccounts(accounts: Stream): Collection +---- ====== The result is that the above method will only have the `Account` instances where their `owner` attribute matches the logged-in user's `name`. @@ -625,9 +670,6 @@ The result is that the above method will only have the `Account` instances where [[use-postfilter]] === Filtering Method Results with `@PostFilter` -[NOTE] -`@PostFilter` is not yet supported for Kotlin-specific data types; for that reason, only Java snippets are shown - When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PostFilter.html[`@PostFilter`] annotation like so: [tabs] @@ -645,6 +687,20 @@ public class BankService { } } ---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@Component +open class BankService { + @PreFilter("filterObject.owner == authentication.name") + fun readAccounts(vararg ids: String): Collection { + // ... the return value will be filtered to only contain the accounts owned by the logged-in user + return accounts + } +} +---- ====== This is meant to filter out any values from the return value where the expression `filterObject.owner == authentication.name` fails. @@ -669,6 +725,22 @@ void readAccountsWhenOwnedThenReturns() { assertThat(accounts.get(0).getOwner()).isEqualTo("owner"); } ---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@Autowired +lateinit var bankService: BankService + +@WithMockUser(username="owner") +@Test +fun readAccountsWhenOwnedThenReturns() { + val accounts: Collection = bankService.updateAccounts("owner", "not-owner") + assertThat(accounts).hasSize(1) + assertThat(accounts[0].owner).isEqualTo("owner") +} +---- ====== [TIP] @@ -678,7 +750,15 @@ void readAccountsWhenOwnedThenReturns() { For example, the above `readAccounts` declaration will function the same way as the following other three: -```java +[tabs] +====== +Java:: ++ +[source,java,role="primary"] +---- +@PostFilter("filterObject.owner == authentication.name") +public Collection readAccounts(String... ids) + @PostFilter("filterObject.owner == authentication.name") public Account[] readAccounts(String... ids) @@ -687,7 +767,25 @@ public Map readAccounts(String... ids) @PostFilter("filterObject.owner == authentication.name") public Stream readAccounts(String... ids) -``` +---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@PostFilter("filterObject.owner == authentication.name") +fun readAccounts(vararg ids: String): Collection + +@PostFilter("filterObject.owner == authentication.name") +fun readAccounts(vararg ids: String): Array + +@PostFilter("filterObject.owner == authentication.name") +fun readAccounts(vararg ids: String): Map + +@PostFilter("filterObject.owner == authentication.name") +fun readAccounts(vararg ids: String): Stream +---- +====== The result is that the above method will return the `Account` instances where their `owner` attribute matches the logged-in user's `name`.