Support Spring Data container types for AuthorizeReturnObject

Closes gh-15994

Signed-off-by: Evgeniy Cheban <mister.cheban@gmail.com>
This commit is contained in:
Evgeniy Cheban 2025-04-17 19:25:15 +03:00 committed by Josh Cummings
parent 6d3b54df21
commit fd4f06a66e
2 changed files with 152 additions and 12 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,13 +16,22 @@
package org.springframework.security.config.annotation.method.configuration;
import java.util.List;
import org.springframework.aop.framework.AopInfrastructureBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Role;
import org.springframework.core.Ordered;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.geo.GeoPage;
import org.springframework.data.geo.GeoResult;
import org.springframework.data.geo.GeoResults;
import org.springframework.security.aot.hint.SecurityHintsRegistrar;
import org.springframework.security.authorization.AuthorizationProxyFactory;
import org.springframework.security.authorization.method.AuthorizationAdvisorProxyFactory;
import org.springframework.security.data.aot.hint.AuthorizeReturnObjectDataHintsRegistrar;
@Configuration(proxyBeanMethods = false)
@ -34,4 +43,45 @@ final class AuthorizationProxyDataConfiguration implements AopInfrastructureBean
return new AuthorizeReturnObjectDataHintsRegistrar(proxyFactory);
}
@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
DataTargetVisitor dataTargetVisitor() {
return new DataTargetVisitor();
}
private static final class DataTargetVisitor implements AuthorizationAdvisorProxyFactory.TargetVisitor, Ordered {
private static final int DEFAULT_ORDER = 200;
@Override
public Object visit(AuthorizationAdvisorProxyFactory proxyFactory, Object target) {
if (target instanceof GeoResults<?> geoResults) {
return new GeoResults<>(proxyFactory.proxy(geoResults.getContent()), geoResults.getAverageDistance());
}
if (target instanceof GeoResult<?> geoResult) {
return new GeoResult<>(proxyFactory.proxy(geoResult.getContent()), geoResult.getDistance());
}
if (target instanceof GeoPage<?> geoPage) {
GeoResults<?> results = new GeoResults<>(proxyFactory.proxy(geoPage.getContent()),
geoPage.getAverageDistance());
return new GeoPage<>(results, geoPage.getPageable(), geoPage.getTotalElements());
}
if (target instanceof PageImpl<?> page) {
List<?> content = proxyFactory.proxy(page.getContent());
return new PageImpl<>(content, page.getPageable(), page.getTotalElements());
}
if (target instanceof SliceImpl<?> slice) {
List<?> content = proxyFactory.proxy(slice.getContent());
return new SliceImpl<>(content, slice.getPageable(), slice.hasNext());
}
return null;
}
@Override
public int getOrder() {
return DEFAULT_ORDER;
}
}
}

View File

@ -63,7 +63,14 @@ import org.springframework.context.annotation.Role;
import org.springframework.context.event.EventListener;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.core.annotation.AnnotationConfigurationException;
import org.springframework.core.annotation.Order;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.geo.Distance;
import org.springframework.data.geo.GeoPage;
import org.springframework.data.geo.GeoResult;
import org.springframework.data.geo.GeoResults;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
@ -756,6 +763,28 @@ public class PrePostMethodSecurityConfigurationTests {
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude);
}
@Test
@WithMockUser(authorities = "airplane:read")
public void findGeoResultByIdWhenAuthorizedResultThenAuthorizes() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
GeoResult<Flight> geoResultFlight = flights.findGeoResultFlightById("1");
Flight flight = geoResultFlight.getContent();
assertThatNoException().isThrownBy(flight::getAltitude);
assertThatNoException().isThrownBy(flight::getSeats);
}
@Test
@WithMockUser(authorities = "seating:read")
public void findGeoResultByIdWhenUnauthorizedResultThenDenies() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
GeoResult<Flight> geoResultFlight = flights.findGeoResultFlightById("1");
Flight flight = geoResultFlight.getContent();
assertThatNoException().isThrownBy(flight::getSeats);
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude);
}
@Test
@WithMockUser(authorities = "airplane:read")
public void findByIdWhenAuthorizedResponseEntityThenAuthorizes() {
@ -827,6 +856,46 @@ public class PrePostMethodSecurityConfigurationTests {
.doesNotContain("Kevin Mitnick"));
}
@Test
@WithMockUser(authorities = "airplane:read")
public void findPageWhenPostFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findPage()
.forEach((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName)
.doesNotContain("Kevin Mitnick"));
}
@Test
@WithMockUser(authorities = "airplane:read")
public void findSliceWhenPostFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findSlice()
.forEach((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName)
.doesNotContain("Kevin Mitnick"));
}
@Test
@WithMockUser(authorities = "airplane:read")
public void findGeoPageWhenPostFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findGeoPage()
.forEach((flight) -> assertThat(flight.getContent().getPassengers()).extracting(Passenger::getName)
.doesNotContain("Kevin Mitnick"));
}
@Test
@WithMockUser(authorities = "airplane:read")
public void findGeoResultsWhenPostFilterThenFilters() {
this.spring.register(AuthorizeResultConfig.class).autowire();
FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
flights.findGeoResults()
.forEach((flight) -> assertThat(flight.getContent().getPassengers()).extracting(Passenger::getName)
.doesNotContain("Kevin Mitnick"));
}
@Test
@WithMockUser(authorities = "airplane:read")
public void findAllWhenPreFilterThenFilters() {
@ -1762,16 +1831,8 @@ public class PrePostMethodSecurityConfigurationTests {
@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
@Order(1)
static TargetVisitor mock() {
return Mockito.mock(TargetVisitor.class);
}
@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
@Order(0)
static TargetVisitor skipValueTypes() {
return TargetVisitor.defaultsSkipValueTypes();
static TargetVisitor customTargetVisitor() {
return TargetVisitor.of(Mockito.mock(), TargetVisitor.defaultsSkipValueTypes());
}
@Bean
@ -1802,10 +1863,39 @@ public class PrePostMethodSecurityConfigurationTests {
return this.flights.values().iterator();
}
Page<Flight> findPage() {
return new PageImpl<>(new ArrayList<>(this.flights.values()));
}
Slice<Flight> findSlice() {
return new SliceImpl<>(new ArrayList<>(this.flights.values()));
}
GeoPage<Flight> findGeoPage() {
List<GeoResult<Flight>> results = new ArrayList<>();
for (Flight flight : this.flights.values()) {
results.add(new GeoResult<>(flight, new Distance(flight.altitude)));
}
return new GeoPage<>(new GeoResults<>(results));
}
GeoResults<Flight> findGeoResults() {
List<GeoResult<Flight>> results = new ArrayList<>();
for (Flight flight : this.flights.values()) {
results.add(new GeoResult<>(flight, new Distance(flight.altitude)));
}
return new GeoResults<>(results);
}
Flight findById(String id) {
return this.flights.get(id);
}
GeoResult<Flight> findGeoResultFlightById(String id) {
Flight flight = this.flights.get(id);
return new GeoResult<>(flight, new Distance(flight.altitude));
}
Flight save(Flight flight) {
this.flights.put(flight.getId(), flight);
return flight;