PayloadInterceptorRSocket retains all payloads
Flux#skip discards its corresponding elements, meaning that they aren't intended for reuse. When using RSocket's ByteBufPayloads, this means that the bytes are releaseed back into RSocket's pool. Since the downstream request may still need the skipped payload, we should construct the publisher in a different way so as to avoid the preemptive release. Deferring Spring JavaFormat to clarify what changed. Closes gh-9345
This commit is contained in:
parent
6cafa48369
commit
b189e0370a
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2019 the original author or authors.
|
||||
* Copyright 2019-2021 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -104,15 +104,18 @@ class PayloadInterceptorRSocket extends RSocketProxy implements ResponderRSocket
|
|||
return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload)
|
||||
.flatMapMany(context ->
|
||||
innerFlux
|
||||
.skip(1)
|
||||
.flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p))
|
||||
.transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads))
|
||||
.index()
|
||||
.concatMap(tuple -> justOrIntercept(tuple.getT1(), tuple.getT2()))
|
||||
.transform(securedPayloads -> this.source.requestChannel(securedPayloads))
|
||||
.subscriberContext(context)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
private Mono<Payload> justOrIntercept(Long index, Payload payload) {
|
||||
return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> metadataPush(Payload payload) {
|
||||
return intercept(PayloadExchangeType.METADATA_PUSH, payload)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2019 the original author or authors.
|
||||
* Copyright 2019-2021 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,8 @@ package org.springframework.security.rsocket.core;
|
|||
import io.rsocket.Payload;
|
||||
import io.rsocket.RSocket;
|
||||
import io.rsocket.metadata.WellKnownMimeType;
|
||||
import io.rsocket.util.ByteBufPayload;
|
||||
import io.rsocket.util.DefaultPayload;
|
||||
import io.rsocket.util.RSocketProxy;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
@ -28,7 +30,9 @@ import org.mockito.Mock;
|
|||
import org.mockito.runners.MockitoJUnitRunner;
|
||||
import org.mockito.stubbing.Answer;
|
||||
import org.reactivestreams.Publisher;
|
||||
import org.reactivestreams.Subscription;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.security.access.AccessDeniedException;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||
|
@ -41,6 +45,8 @@ import org.springframework.security.rsocket.core.DefaultPayloadExchange;
|
|||
import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
|
||||
import org.springframework.util.MimeType;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
import reactor.util.context.Context;
|
||||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.test.StepVerifier;
|
||||
|
@ -50,10 +56,13 @@ import reactor.test.publisher.TestPublisher;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
@ -315,6 +324,57 @@ public class PayloadInterceptorRSocketTests {
|
|||
verify(this.delegate).requestChannel(any());
|
||||
}
|
||||
|
||||
// gh-9345
|
||||
@Test
|
||||
public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() {
|
||||
ExecutorService executors = Executors.newSingleThreadExecutor();
|
||||
Payload payload = ByteBufPayload.create("data");
|
||||
Payload payloadTwo = ByteBufPayload.create("moredata");
|
||||
Payload payloadThree = ByteBufPayload.create("stillmoredata");
|
||||
Context ctx = Context.empty();
|
||||
Flux<Payload> payloads = this.payloadResult.flux();
|
||||
when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty())
|
||||
.thenReturn(Mono.error(() -> new AccessDeniedException("Access Denied")));
|
||||
when(this.delegate.requestChannel(any())).thenAnswer((invocation) -> {
|
||||
Flux<Payload> input = invocation.getArgument(0);
|
||||
return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8)
|
||||
.transform((data) -> Flux.<String>create((emitter) -> {
|
||||
Runnable run = () -> data.subscribe(new CoreSubscriber<String>() {
|
||||
@Override
|
||||
public void onSubscribe(Subscription s) {
|
||||
s.request(3);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(String s) {
|
||||
emitter.next(s);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
emitter.error(t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
emitter.complete();
|
||||
}
|
||||
});
|
||||
executors.execute(run);
|
||||
})).map(DefaultPayload::create));
|
||||
});
|
||||
PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
|
||||
Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx);
|
||||
StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release))
|
||||
.then(() -> this.payloadResult.assertSubscribers())
|
||||
.then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree))
|
||||
.assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8()))
|
||||
.verifyError(AccessDeniedException.class);
|
||||
verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any());
|
||||
assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo);
|
||||
verify(this.delegate).requestChannel(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
|
||||
RuntimeException expected = new RuntimeException("Oops");
|
||||
|
|
Loading…
Reference in New Issue