diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java index 418fb67121..0c146ea9d2 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -104,15 +104,18 @@ class PayloadInterceptorRSocket extends RSocketProxy implements ResponderRSocket return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload) .flatMapMany(context -> innerFlux - .skip(1) - .flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p)) - .transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads)) + .index() + .concatMap(tuple -> justOrIntercept(tuple.getT1(), tuple.getT2())) .transform(securedPayloads -> this.source.requestChannel(securedPayloads)) .subscriberContext(context) ); }); } + private Mono justOrIntercept(Long index, Payload payload) { + return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload); + } + @Override public Mono metadataPush(Payload payload) { return intercept(PayloadExchangeType.METADATA_PUSH, payload) diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java index a925ac676e..fa149e453a 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ package org.springframework.security.rsocket.core; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; import io.rsocket.util.RSocketProxy; import org.junit.Test; import org.junit.runner.RunWith; @@ -28,7 +30,9 @@ import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; import org.springframework.http.MediaType; +import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; @@ -41,6 +45,8 @@ import org.springframework.security.rsocket.core.DefaultPayloadExchange; import org.springframework.security.rsocket.core.PayloadInterceptorRSocket; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; +import reactor.util.context.Context; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -50,10 +56,13 @@ import reactor.test.publisher.TestPublisher; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -315,6 +324,57 @@ public class PayloadInterceptorRSocketTests { verify(this.delegate).requestChannel(any()); } + // gh-9345 + @Test + public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() { + ExecutorService executors = Executors.newSingleThreadExecutor(); + Payload payload = ByteBufPayload.create("data"); + Payload payloadTwo = ByteBufPayload.create("moredata"); + Payload payloadThree = ByteBufPayload.create("stillmoredata"); + Context ctx = Context.empty(); + Flux payloads = this.payloadResult.flux(); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()) + .thenReturn(Mono.error(() -> new AccessDeniedException("Access Denied"))); + when(this.delegate.requestChannel(any())).thenAnswer((invocation) -> { + Flux input = invocation.getArgument(0); + return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8) + .transform((data) -> Flux.create((emitter) -> { + Runnable run = () -> data.subscribe(new CoreSubscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(3); + } + + @Override + public void onNext(String s) { + emitter.next(s); + } + + @Override + public void onError(Throwable t) { + emitter.error(t); + } + + @Override + public void onComplete() { + emitter.complete(); + } + }); + executors.execute(run); + })).map(DefaultPayload::create)); + }); + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx); + StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree)) + .assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8())) + .verifyError(AccessDeniedException.class); + verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo); + verify(this.delegate).requestChannel(any()); + } + @Test public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() { RuntimeException expected = new RuntimeException("Oops");