PayloadInterceptorRSocket retains all payloads

Flux#skip discards its corresponding elements, meaning that they
aren't intended for reuse. When using RSocket's ByteBufPayloads,
this means that the bytes are releaseed back into RSocket's pool.

Since the downstream request may still need the skipped payload,
we should construct the publisher in a different way so as to
avoid the preemptive release.

Deferring Spring JavaFormat to clarify what changed.

Closes gh-9345
This commit is contained in:
Josh Cummings 2021-05-28 12:18:15 -06:00
parent 6cafa48369
commit b189e0370a
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 68 additions and 5 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2019 the original author or authors.
* Copyright 2019-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -104,15 +104,18 @@ class PayloadInterceptorRSocket extends RSocketProxy implements ResponderRSocket
return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload)
.flatMapMany(context ->
innerFlux
.skip(1)
.flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p))
.transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads))
.index()
.concatMap(tuple -> justOrIntercept(tuple.getT1(), tuple.getT2()))
.transform(securedPayloads -> this.source.requestChannel(securedPayloads))
.subscriberContext(context)
);
});
}
private Mono<Payload> justOrIntercept(Long index, Payload payload) {
return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload);
}
@Override
public Mono<Void> metadataPush(Payload payload) {
return intercept(PayloadExchangeType.METADATA_PUSH, payload)

View File

@ -1,5 +1,5 @@
/*
* Copyright 2019 the original author or authors.
* Copyright 2019-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,8 @@ package org.springframework.security.rsocket.core;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.metadata.WellKnownMimeType;
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.RSocketProxy;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -28,7 +30,9 @@ import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.http.MediaType;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
@ -41,6 +45,8 @@ import org.springframework.security.rsocket.core.DefaultPayloadExchange;
import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import reactor.util.context.Context;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
@ -50,10 +56,13 @@ import reactor.test.publisher.TestPublisher;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
@ -315,6 +324,57 @@ public class PayloadInterceptorRSocketTests {
verify(this.delegate).requestChannel(any());
}
// gh-9345
@Test
public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() {
ExecutorService executors = Executors.newSingleThreadExecutor();
Payload payload = ByteBufPayload.create("data");
Payload payloadTwo = ByteBufPayload.create("moredata");
Payload payloadThree = ByteBufPayload.create("stillmoredata");
Context ctx = Context.empty();
Flux<Payload> payloads = this.payloadResult.flux();
when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty())
.thenReturn(Mono.error(() -> new AccessDeniedException("Access Denied")));
when(this.delegate.requestChannel(any())).thenAnswer((invocation) -> {
Flux<Payload> input = invocation.getArgument(0);
return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8)
.transform((data) -> Flux.<String>create((emitter) -> {
Runnable run = () -> data.subscribe(new CoreSubscriber<String>() {
@Override
public void onSubscribe(Subscription s) {
s.request(3);
}
@Override
public void onNext(String s) {
emitter.next(s);
}
@Override
public void onError(Throwable t) {
emitter.error(t);
}
@Override
public void onComplete() {
emitter.complete();
}
});
executors.execute(run);
})).map(DefaultPayload::create));
});
PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx);
StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release))
.then(() -> this.payloadResult.assertSubscribers())
.then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree))
.assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8()))
.verifyError(AccessDeniedException.class);
verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any());
assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo);
verify(this.delegate).requestChannel(any());
}
@Test
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
RuntimeException expected = new RuntimeException("Oops");