Add Kotlin support to PreFilter and PostFilter annotations

Closes gh-15093
This commit is contained in:
Blagoja Stamatovski 2024-05-18 15:39:27 +02:00 committed by Josh Cummings
parent fbeb82ef62
commit 63f48167bd
4 changed files with 387 additions and 21 deletions

View File

@ -1,6 +1,8 @@
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import java.util.concurrent.Callable import java.util.concurrent.Callable
apply plugin: 'io.spring.convention.spring-module' apply plugin: 'io.spring.convention.spring-module'
apply plugin: 'kotlin'
dependencies { dependencies {
management platform(project(":spring-security-dependencies")) management platform(project(":spring-security-dependencies"))
@ -31,6 +33,9 @@ dependencies {
testImplementation "org.springframework:spring-test" testImplementation "org.springframework:spring-test"
testImplementation 'org.skyscreamer:jsonassert' testImplementation 'org.skyscreamer:jsonassert'
testImplementation 'org.springframework:spring-test' testImplementation 'org.springframework:spring-test'
testImplementation 'org.jetbrains.kotlin:kotlin-reflect'
testImplementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8'
testImplementation 'io.mockk:mockk'
testRuntimeOnly 'org.hsqldb:hsqldb' testRuntimeOnly 'org.hsqldb:hsqldb'
} }
@ -57,3 +62,12 @@ Callable<String> springVersion() {
return (Callable<String>) { project.configurations.compileClasspath.resolvedConfiguration.resolvedArtifacts return (Callable<String>) { project.configurations.compileClasspath.resolvedConfiguration.resolvedArtifacts
.find { it.name == 'spring-core' }.moduleVersion.id.version } .find { it.name == 'spring-core' }.moduleVersion.id.version }
} }
tasks.withType(KotlinCompile).configureEach {
kotlinOptions {
languageVersion = "1.7"
apiVersion = "1.7"
freeCompilerArgs = ["-Xjsr305=strict", "-Xsuppress-version-warnings"]
jvmTarget = "17"
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -52,6 +52,7 @@ import org.springframework.util.Assert;
* *
* @author Luke Taylor * @author Luke Taylor
* @author Evgeniy Cheban * @author Evgeniy Cheban
* @author Blagoja Stamatovski
* @since 3.0 * @since 3.0
*/ */
public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpressionHandler<MethodInvocation> public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpressionHandler<MethodInvocation>
@ -109,12 +110,13 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
} }
/** /**
* Filters the {@code filterTarget} object (which must be either a collection, array, * Filters the {@code filterTarget} object (which must be either a {@link Collection},
* map or stream), by evaluating the supplied expression. * {@code Array}, {@link Map} or {@link Stream}), by evaluating the supplied
* expression.
* <p> * <p>
* If a {@code Collection} or {@code Map} is used, the original instance will be * Returns new instances of the same type as the supplied {@code filterTarget} object
* modified to contain the elements for which the permission expression evaluates to * @return The filtered {@link Collection}, {@code Array}, {@link Map} or
* {@code true}. For an array, a new array instance will be returned. * {@link Stream}
*/ */
@Override @Override
public Object filter(Object filterTarget, Expression filterExpression, EvaluationContext ctx) { public Object filter(Object filterTarget, Expression filterExpression, EvaluationContext ctx) {
@ -151,9 +153,17 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
} }
} }
this.logger.debug(LogMessage.format("Retaining elements: %s", retain)); this.logger.debug(LogMessage.format("Retaining elements: %s", retain));
filterTarget.clear(); try {
filterTarget.addAll(retain); filterTarget.clear();
return filterTarget; filterTarget.addAll(retain);
return filterTarget;
}
catch (UnsupportedOperationException unsupportedOperationException) {
this.logger.debug(LogMessage.format(
"Collection threw exception: %s. Will return a new instance instead of mutating its state.",
unsupportedOperationException.getMessage()));
return retain;
}
} }
private Object filterArray(Object[] filterTarget, Expression filterExpression, EvaluationContext ctx, private Object filterArray(Object[] filterTarget, Expression filterExpression, EvaluationContext ctx,
@ -178,7 +188,7 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
return filtered; return filtered;
} }
private <K, V> Object filterMap(final Map<K, V> filterTarget, Expression filterExpression, EvaluationContext ctx, private <K, V> Object filterMap(Map<K, V> filterTarget, Expression filterExpression, EvaluationContext ctx,
MethodSecurityExpressionOperations rootObject) { MethodSecurityExpressionOperations rootObject) {
Map<K, V> retain = new LinkedHashMap<>(filterTarget.size()); Map<K, V> retain = new LinkedHashMap<>(filterTarget.size());
this.logger.debug(LogMessage.format("Filtering map with %s elements", filterTarget.size())); this.logger.debug(LogMessage.format("Filtering map with %s elements", filterTarget.size()));
@ -189,9 +199,17 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
} }
} }
this.logger.debug(LogMessage.format("Retaining elements: %s", retain)); this.logger.debug(LogMessage.format("Retaining elements: %s", retain));
filterTarget.clear(); try {
filterTarget.putAll(retain); filterTarget.clear();
return filterTarget; filterTarget.putAll(retain);
return filterTarget;
}
catch (UnsupportedOperationException unsupportedOperationException) {
this.logger.debug(LogMessage.format(
"Map threw exception: %s. Will return a new instance instead of mutating its state.",
unsupportedOperationException.getMessage()));
return retain;
}
} }
private Object filterStream(final Stream<?> filterTarget, Expression filterExpression, EvaluationContext ctx, private Object filterStream(final Stream<?> filterTarget, Expression filterExpression, EvaluationContext ctx,

View File

@ -0,0 +1,236 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.access.expression.method
import io.mockk.every
import io.mockk.mockk
import org.aopalliance.intercept.MethodInvocation
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.expression.EvaluationContext
import org.springframework.expression.Expression
import org.springframework.security.core.Authentication
import java.util.stream.Stream
import kotlin.reflect.jvm.internal.impl.load.kotlin.JvmType
import kotlin.reflect.jvm.javaMethod
/**
* @author Blagoja Stamatovski
*/
class DefaultMethodSecurityExpressionHandlerKotlinTests {
private object Foo {
fun bar() {
}
}
private lateinit var authentication: Authentication
private lateinit var methodInvocation: MethodInvocation
private val handler: MethodSecurityExpressionHandler = DefaultMethodSecurityExpressionHandler()
@BeforeEach
fun setUp() {
authentication = mockk()
methodInvocation = mockk()
every { methodInvocation.`this` } returns { Foo }
every { methodInvocation.method } answers { Foo::bar.javaMethod!! }
every { methodInvocation.arguments } answers { arrayOf<JvmType.Object>() }
}
@Test
fun `filters non-empty maps`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject.key eq 'key2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val nonEmptyMap: Map<String, String> = mapOf(
"key1" to "value1",
"key2" to "value2",
"key3" to "value3",
)
val filtered: Any = handler.filter(
/* filterTarget = */ nonEmptyMap,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Map::class.java)
val result = (filtered as Map<String, String>)
assertThat(result).hasSize(1)
assertThat(result).containsKey("key2")
assertThat(result).containsValue("value2")
}
@Test
fun `filters empty maps`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject.key eq 'key2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val emptyMap: Map<String, String> = emptyMap()
val filtered: Any = handler.filter(
/* filterTarget = */ emptyMap,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Map::class.java)
val result = (filtered as Map<String, String>)
assertThat(result).hasSize(0)
}
@Test
fun `filters non-empty collections`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val nonEmptyCollection: Collection<String> = listOf(
"string1",
"string2",
"string1",
)
val filtered: Any = handler.filter(
/* filterTarget = */ nonEmptyCollection,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Collection::class.java)
val result = (filtered as Collection<String>)
assertThat(result).hasSize(1)
assertThat(result).contains("string2")
}
@Test
fun `filters empty collections`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val emptyCollection: Collection<String> = emptyList()
val filtered: Any = handler.filter(
/* filterTarget = */ emptyCollection,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Collection::class.java)
val result = (filtered as Collection<String>)
assertThat(result).hasSize(0)
}
@Test
fun `filters non-empty arrays`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val nonEmptyArray: Array<String> = arrayOf(
"string1",
"string2",
"string1",
)
val filtered: Any = handler.filter(
/* filterTarget = */ nonEmptyArray,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Array<String>::class.java)
val result = (filtered as Array<String>)
assertThat(result).hasSize(1)
assertThat(result).contains("string2")
}
@Test
fun `filters empty arrays`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val emptyArray: Array<String> = emptyArray()
val filtered: Any = handler.filter(
/* filterTarget = */ emptyArray,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Array<String>::class.java)
val result = (filtered as Array<String>)
assertThat(result).hasSize(0)
}
@Test
fun `filters non-empty streams`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val nonEmptyStream: Stream<String> = listOf(
"string1",
"string2",
"string1",
).stream()
val filtered: Any = handler.filter(
/* filterTarget = */ nonEmptyStream,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Stream::class.java)
val result = (filtered as Stream<String>).toList()
assertThat(result).hasSize(1)
assertThat(result).contains("string2")
}
@Test
fun `filters empty streams`() {
val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
val context: EvaluationContext = handler.createEvaluationContext(
/* authentication = */ authentication,
/* invocation = */ methodInvocation,
)
val emptyStream: Stream<String> = emptyList<String>().stream()
val filtered: Any = handler.filter(
/* filterTarget = */ emptyStream,
/* filterExpression = */ expression,
/* ctx = */ context,
)
assertThat(filtered).isInstanceOf(Stream::class.java)
val result = (filtered as Stream<String>).toList()
assertThat(result).hasSize(0)
}
}

View File

@ -546,9 +546,6 @@ If not, Spring Security will throw an `AccessDeniedException` and return a 403 s
[[use-prefilter]] [[use-prefilter]]
=== Filtering Method Parameters with `@PreFilter` === Filtering Method Parameters with `@PreFilter`
[NOTE]
`@PreFilter` is not yet supported for Kotlin-specific data types; for that reason, only Java snippets are shown
When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PreFilter.html[`@PreFilter`] annotation like so: When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PreFilter.html[`@PreFilter`] annotation like so:
[tabs] [tabs]
@ -566,6 +563,20 @@ public class BankService {
} }
} }
---- ----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@Component
open class BankService {
@PreFilter("filterObject.owner == authentication.name")
fun updateAccounts(vararg accounts: Account): Collection<Account> {
// ... `accounts` will only contain the accounts owned by the logged-in user
return updated
}
}
----
====== ======
This is meant to filter out any values from `accounts` where the expression `filterObject.owner == authentication.name` fails. This is meant to filter out any values from `accounts` where the expression `filterObject.owner == authentication.name` fails.
@ -591,6 +602,23 @@ void updateAccountsWhenOwnedThenReturns() {
assertThat(updated).containsOnly(ownedBy); assertThat(updated).containsOnly(ownedBy);
} }
---- ----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@Autowired
lateinit var bankService: BankService
@WithMockUser(username="owner")
@Test
fun updateAccountsWhenOwnedThenReturns() {
val ownedBy: Account = ...
val notOwnedBy: Account = ...
val updated: Collection<Account> = bankService.updateAccounts(ownedBy, notOwnedBy)
assertThat(updated).containsOnly(ownedBy)
}
----
====== ======
[TIP] [TIP]
@ -618,6 +646,23 @@ public Collection<Account> updateAccounts(Map<String, Account> accounts)
@PreFilter("filterObject.owner == authentication.name") @PreFilter("filterObject.owner == authentication.name")
public Collection<Account> updateAccounts(Stream<Account> accounts) public Collection<Account> updateAccounts(Stream<Account> accounts)
---- ----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@PreFilter("filterObject.owner == authentication.name")
fun updateAccounts(accounts: Array<Account>): Collection<Account>
@PreFilter("filterObject.owner == authentication.name")
fun updateAccounts(accounts: Collection<Account>): Collection<Account>
@PreFilter("filterObject.value.owner == authentication.name")
fun updateAccounts(accounts: Map<String, Account>): Collection<Account>
@PreFilter("filterObject.owner == authentication.name")
fun updateAccounts(accounts: Stream<Account>): Collection<Account>
----
====== ======
The result is that the above method will only have the `Account` instances where their `owner` attribute matches the logged-in user's `name`. The result is that the above method will only have the `Account` instances where their `owner` attribute matches the logged-in user's `name`.
@ -625,9 +670,6 @@ The result is that the above method will only have the `Account` instances where
[[use-postfilter]] [[use-postfilter]]
=== Filtering Method Results with `@PostFilter` === Filtering Method Results with `@PostFilter`
[NOTE]
`@PostFilter` is not yet supported for Kotlin-specific data types; for that reason, only Java snippets are shown
When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PostFilter.html[`@PostFilter`] annotation like so: When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PostFilter.html[`@PostFilter`] annotation like so:
[tabs] [tabs]
@ -645,6 +687,20 @@ public class BankService {
} }
} }
---- ----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@Component
open class BankService {
@PreFilter("filterObject.owner == authentication.name")
fun readAccounts(vararg ids: String): Collection<Account> {
// ... the return value will be filtered to only contain the accounts owned by the logged-in user
return accounts
}
}
----
====== ======
This is meant to filter out any values from the return value where the expression `filterObject.owner == authentication.name` fails. This is meant to filter out any values from the return value where the expression `filterObject.owner == authentication.name` fails.
@ -669,6 +725,22 @@ void readAccountsWhenOwnedThenReturns() {
assertThat(accounts.get(0).getOwner()).isEqualTo("owner"); assertThat(accounts.get(0).getOwner()).isEqualTo("owner");
} }
---- ----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@Autowired
lateinit var bankService: BankService
@WithMockUser(username="owner")
@Test
fun readAccountsWhenOwnedThenReturns() {
val accounts: Collection<Account> = bankService.updateAccounts("owner", "not-owner")
assertThat(accounts).hasSize(1)
assertThat(accounts[0].owner).isEqualTo("owner")
}
----
====== ======
[TIP] [TIP]
@ -678,7 +750,15 @@ void readAccountsWhenOwnedThenReturns() {
For example, the above `readAccounts` declaration will function the same way as the following other three: For example, the above `readAccounts` declaration will function the same way as the following other three:
```java [tabs]
======
Java::
+
[source,java,role="primary"]
----
@PostFilter("filterObject.owner == authentication.name")
public Collection<Account> readAccounts(String... ids)
@PostFilter("filterObject.owner == authentication.name") @PostFilter("filterObject.owner == authentication.name")
public Account[] readAccounts(String... ids) public Account[] readAccounts(String... ids)
@ -687,7 +767,25 @@ public Map<String, Account> readAccounts(String... ids)
@PostFilter("filterObject.owner == authentication.name") @PostFilter("filterObject.owner == authentication.name")
public Stream<Account> readAccounts(String... ids) public Stream<Account> readAccounts(String... ids)
``` ----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@PostFilter("filterObject.owner == authentication.name")
fun readAccounts(vararg ids: String): Collection<Account>
@PostFilter("filterObject.owner == authentication.name")
fun readAccounts(vararg ids: String): Array<Account>
@PostFilter("filterObject.owner == authentication.name")
fun readAccounts(vararg ids: String): Map<String, Account>
@PostFilter("filterObject.owner == authentication.name")
fun readAccounts(vararg ids: String): Stream<Account>
----
======
The result is that the above method will return the `Account` instances where their `owner` attribute matches the logged-in user's `name`. The result is that the above method will return the `Account` instances where their `owner` attribute matches the logged-in user's `name`.