diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java index 5e8839c139..a5a849096b 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java @@ -72,6 +72,7 @@ class PayloadSocketAcceptor implements SocketAcceptor { return intercept(setup, dataMimeType, metadataMimeType) .flatMap(ctx -> this.delegate.accept(setup, sendingSocket) .map(acceptingSocket -> new PayloadInterceptorRSocket(acceptingSocket, this.interceptors, metadataMimeType, dataMimeType, ctx)) + .subscriberContext(ctx) ); } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java new file mode 100644 index 0000000000..f434ed4c70 --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java @@ -0,0 +1,50 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.core; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import reactor.core.publisher.Mono; + +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; + +/** + * A {@link SocketAcceptor} that captures the {@link SecurityContext} and then continues with the {@link RSocket} + * @author Rob Winch + */ +class CaptureSecurityContextSocketAcceptor implements SocketAcceptor { + private final RSocket accept; + + private SecurityContext securityContext; + + CaptureSecurityContextSocketAcceptor(RSocket accept) { + this.accept = accept; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + return ReactiveSecurityContextHolder.getContext() + .doOnNext(securityContext -> this.securityContext = securityContext) + .thenReturn(this.accept); + } + + public SecurityContext getSecurityContext() { + return this.securityContext; + } +} diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java index 943fc978b9..69b7e2356d 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java @@ -16,6 +16,10 @@ package org.springframework.security.rsocket.core; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -27,16 +31,16 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + import org.springframework.http.MediaType; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadInterceptor; -import org.springframework.security.rsocket.core.PayloadInterceptorRSocket; -import org.springframework.security.rsocket.core.PayloadSocketAcceptor; -import reactor.core.publisher.Mono; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -144,6 +148,27 @@ public class PayloadSocketAcceptorTests { assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); } + + @Test + // gh-8654 + public void acceptWhenDelegateAcceptRequiresReactiveSecurityContext() { + when(this.setupPayload.metadataMimeType()).thenReturn(MediaType.TEXT_PLAIN_VALUE); + when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); + SecurityContext expectedSecurityContext = new SecurityContextImpl(new TestingAuthenticationToken("user", "password", "ROLE_USER")); + CaptureSecurityContextSocketAcceptor captureSecurityContext = new CaptureSecurityContextSocketAcceptor(this.rSocket); + PayloadInterceptor authenticateInterceptor = (exchange, chain) -> { + Context withSecurityContext = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedSecurityContext)); + return chain.next(exchange) + .subscriberContext(withSecurityContext); + }; + List interceptors = Arrays.asList(authenticateInterceptor); + this.acceptor = new PayloadSocketAcceptor(captureSecurityContext, interceptors); + + this.acceptor.accept(this.setupPayload, this.rSocket).block(); + + assertThat(captureSecurityContext.getSecurityContext()).isEqualTo(expectedSecurityContext); + } + private PayloadExchange captureExchange() { when(this.delegate.accept(any(), any())).thenReturn(Mono.just(this.rSocket)); when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());