From 5a4eded696a32be563cf403e3073e5643c3b27ad Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Wed, 4 Sep 2019 19:16:56 -0500 Subject: [PATCH] Add RSocket Support Fixes gh-7360 --- config/spring-security-config.gradle | 3 + .../rsocket/EnableRSocketSecurity.java | 39 ++ .../annotation/rsocket/RSocketSecurity.java | 313 +++++++++++ .../rsocket/RSocketSecurityConfiguration.java | 84 +++ .../annotation/rsocket/HelloHandler.java | 41 ++ .../config/annotation/rsocket/JwtITests.java | 182 +++++++ ...RSocketMessageHandlerConnectionITests.java | 246 +++++++++ .../rsocket/RSocketMessageHandlerITests.java | 312 +++++++++++ etc/checkstyle/header.txt | 4 +- gradle/dependency-management.gradle | 8 +- rsocket/spring-security-rsocket.gradle | 9 + .../ContextPayloadInterceptorChain.java | 96 ++++ .../interceptor/DefaultPayloadExchange.java | 70 +++ .../rsocket/interceptor/PayloadExchange.java | 36 ++ .../interceptor/PayloadExchangeType.java | 80 +++ .../interceptor/PayloadInterceptor.java | 38 ++ .../interceptor/PayloadInterceptorChain.java | 34 ++ .../PayloadInterceptorRSocket.java | 140 +++++ .../interceptor/PayloadSocketAcceptor.java | 99 ++++ .../PayloadSocketAcceptorInterceptor.java | 66 +++ .../AnonymousPayloadInterceptor.java | 83 +++ .../AuthenticationPayloadInterceptor.java | 74 +++ ...uthenticationPayloadExchangeConverter.java | 60 +++ .../BearerPayloadExchangeConverter.java | 54 ++ ...ayloadExchangeAuthenticationConverter.java | 30 ++ .../AuthorizationPayloadInterceptor.java | 53 ++ ...geMatcherReactiveAuthorizationManager.java | 82 +++ .../metadata/BasicAuthenticationDecoder.java | 76 +++ .../metadata/BasicAuthenticationEncoder.java | 76 +++ .../rsocket/metadata/BearerTokenMetadata.java | 47 ++ .../metadata/UsernamePasswordMetadata.java | 55 ++ .../PayloadExchangeAuthorizationContext.java | 48 ++ .../rsocket/util/PayloadExchangeMatcher.java | 89 +++ .../util/PayloadExchangeMatcherEntry.java | 38 ++ .../rsocket/util/PayloadExchangeMatchers.java | 57 ++ .../util/RoutePayloadExchangeMatcher.java | 61 +++ .../AnonymousPayloadInterceptorTests.java | 108 ++++ ...AuthenticationPayloadInterceptorChain.java | 45 ++ ...AuthenticationPayloadInterceptorTests.java | 148 +++++ .../AuthorizationPayloadInterceptorTests.java | 118 ++++ .../PayloadInterceptorRSocketTests.java | 509 ++++++++++++++++++ ...PayloadSocketAcceptorInterceptorTests.java | 121 +++++ .../PayloadSocketAcceptorTests.java | 160 ++++++ ...tcherReactiveAuthorizationManagerTest.java | 108 ++++ .../BasicAuthenticationDecoderTests.java | 54 ++ .../RoutePayloadExchangeMatcherTests.java | 116 ++++ 46 files changed, 4366 insertions(+), 4 deletions(-) create mode 100644 config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java create mode 100644 config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java create mode 100644 config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java create mode 100644 config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java create mode 100644 rsocket/spring-security-rsocket.gradle create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/ContextPayloadInterceptorChain.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/DefaultPayloadExchange.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchange.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchangeType.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptor.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorChain.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocket.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptor.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptor.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AnonymousPayloadInterceptor.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AuthenticationPayloadInterceptor.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BasicAuthenticationPayloadExchangeConverter.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BearerPayloadExchangeConverter.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/PayloadExchangeAuthenticationConverter.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/AuthorizationPayloadInterceptor.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeAuthorizationContext.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcher.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcherEntry.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatchers.java create mode 100644 rsocket/src/main/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcher.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocketTests.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptorTests.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorTests.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java create mode 100644 rsocket/src/test/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcherTests.java diff --git a/config/spring-security-config.gradle b/config/spring-security-config.gradle index 580a2bdbc2..718d717a99 100644 --- a/config/spring-security-config.gradle +++ b/config/spring-security-config.gradle @@ -15,10 +15,12 @@ dependencies { optional project(':spring-security-oauth2-jose') optional project(':spring-security-oauth2-resource-server') optional project(':spring-security-openid') + optional project(':spring-security-rsocket') optional project(':spring-security-web') optional 'io.projectreactor:reactor-core' optional 'org.aspectj:aspectjweaver' optional 'org.springframework:spring-jdbc' + optional 'org.springframework:spring-messaging' optional 'org.springframework:spring-tx' optional 'org.springframework:spring-webmvc' optional'org.springframework:spring-web' @@ -39,6 +41,7 @@ dependencies { testCompile 'com.squareup.okhttp3:mockwebserver' testCompile 'ch.qos.logback:logback-classic' testCompile 'io.projectreactor.netty:reactor-netty' + testCompile 'io.rsocket:rsocket-transport-netty' testCompile 'javax.annotation:jsr250-api:1.0' testCompile 'javax.xml.bind:jaxb-api' testCompile 'ldapsdk:ldapsdk:4.1' diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java new file mode 100644 index 0000000000..e4dce801f9 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java @@ -0,0 +1,39 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.rsocket; + +import org.springframework.context.annotation.Import; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Add this annotation to a {@code Configuration} class to have Spring Security + * {@link RSocketSecurity} support added. + * + * @author Rob Winch + * @since 5.2 + * @see RSocketSecurity + */ +@Documented +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +@Import({ RSocketSecurityConfiguration.class }) +public @interface EnableRSocketSecurity { } diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java new file mode 100644 index 0000000000..274dc4b539 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java @@ -0,0 +1,313 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.rsocket; + +import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; +import org.springframework.core.ResolvableType; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; +import org.springframework.security.authentication.ReactiveAuthenticationManager; +import org.springframework.security.authorization.AuthenticatedReactiveAuthorizationManager; +import org.springframework.security.authorization.AuthorityReactiveAuthorizationManager; +import org.springframework.security.authorization.AuthorizationDecision; +import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.config.Customizer; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager; +import org.springframework.security.rsocket.interceptor.PayloadInterceptor; +import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor; +import org.springframework.security.rsocket.interceptor.authentication.AnonymousPayloadInterceptor; +import org.springframework.security.rsocket.interceptor.authentication.AuthenticationPayloadInterceptor; +import org.springframework.security.rsocket.interceptor.authentication.BearerPayloadExchangeConverter; +import org.springframework.security.rsocket.interceptor.authorization.AuthorizationPayloadInterceptor; +import org.springframework.security.rsocket.interceptor.authorization.PayloadExchangeMatcherReactiveAuthorizationManager; +import org.springframework.security.rsocket.util.PayloadExchangeAuthorizationContext; +import org.springframework.security.rsocket.util.PayloadExchangeMatcher; +import org.springframework.security.rsocket.util.PayloadExchangeMatcherEntry; +import org.springframework.security.rsocket.util.PayloadExchangeMatchers; +import org.springframework.security.rsocket.util.RoutePayloadExchangeMatcher; +import reactor.core.publisher.Mono; + +import java.util.ArrayList; +import java.util.List; + +/** + * Allows configuring RSocket based security. + * + * A minimal example can be found below: + * + *
+ * @EnableRSocketSecurity
+ * public class SecurityConfig {
+ *     // @formatter:off
+ *     @Bean
+ *     PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
+ *         rsocket
+ *             .authorizePayload(authorize ->
+ *                 authorize
+ *                     .anyRequest().authenticated()
+ *             );
+ *         return rsocket.build();
+ *     }
+ *     // @formatter:on
+ *
+ *     // @formatter:off
+ *     @Bean
+ *     public MapReactiveUserDetailsService userDetailsService() {
+ *          UserDetails user = User.withDefaultPasswordEncoder()
+ *               .username("user")
+ *               .password("password")
+ *               .roles("USER")
+ *               .build();
+ *          return new MapReactiveUserDetailsService(user);
+ *     }
+ *     // @formatter:on
+ * }
+ * 
+ * + * A more advanced configuration can be seen below: + * + *
+ * @EnableRSocketSecurity
+ * public class SecurityConfig {
+ *     // @formatter:off
+ *     @Bean
+ *     PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
+ *         rsocket
+ *             .authorizePayload(authorize ->
+ *                 authorize
+ *                     // must have ROLE_SETUP to make connection
+ *                     .setup().hasRole("SETUP")
+ *                      // must have ROLE_ADMIN for routes starting with "admin."
+ *                     .route("admin.*").hasRole("ADMIN")
+ *                     // any other request must be authenticated for
+ *                     .anyRequest().authenticated()
+ *             );
+ *         return rsocket.build();
+ *     }
+ *     // @formatter:on
+ * }
+ * 
+ * @author Rob Winch + * @since 5.2 + */ +public class RSocketSecurity { + + private BasicAuthenticationSpec basicAuthSpec; + + private JwtSpec jwtSpec; + + private AuthorizePayloadsSpec authorizePayload; + + private ApplicationContext context; + + private ReactiveAuthenticationManager authenticationManager; + + public RSocketSecurity authenticationManager(ReactiveAuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + return this; + } + + public RSocketSecurity basicAuthentication(Customizer basic) { + if (this.basicAuthSpec == null) { + this.basicAuthSpec = new BasicAuthenticationSpec(); + } + basic.customize(this.basicAuthSpec); + return this; + } + + public class BasicAuthenticationSpec { + private ReactiveAuthenticationManager authenticationManager; + + public BasicAuthenticationSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + return this; + } + + private ReactiveAuthenticationManager getAuthenticationManager() { + if (this.authenticationManager == null) { + return RSocketSecurity.this.authenticationManager; + } + return this.authenticationManager; + } + + protected AuthenticationPayloadInterceptor build() { + ReactiveAuthenticationManager manager = getAuthenticationManager(); + return new AuthenticationPayloadInterceptor(manager); + } + + private BasicAuthenticationSpec() {} + } + + public RSocketSecurity jwt(Customizer jwt) { + if (this.jwtSpec == null) { + this.jwtSpec = new JwtSpec(); + } + jwt.customize(this.jwtSpec); + return this; + } + + public class JwtSpec { + private ReactiveAuthenticationManager authenticationManager; + + public JwtSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + return this; + } + + private ReactiveAuthenticationManager getAuthenticationManager() { + if (this.authenticationManager != null) { + return this.authenticationManager; + } + ReactiveJwtDecoder jwtDecoder = getBeanOrNull(ReactiveJwtDecoder.class); + if (jwtDecoder != null) { + this.authenticationManager = new JwtReactiveAuthenticationManager(jwtDecoder); + return this.authenticationManager; + } + return RSocketSecurity.this.authenticationManager; + } + + protected AuthenticationPayloadInterceptor build() { + ReactiveAuthenticationManager manager = getAuthenticationManager(); + AuthenticationPayloadInterceptor result = new AuthenticationPayloadInterceptor(manager); + result.setAuthenticationConverter(new BearerPayloadExchangeConverter()); + return result; + } + + private JwtSpec() {} + } + + public RSocketSecurity authorizePayload(Customizer authorize) { + if (this.authorizePayload == null) { + this.authorizePayload = new AuthorizePayloadsSpec(); + } + authorize.customize(this.authorizePayload); + return this; + } + + public PayloadSocketAcceptorInterceptor build() { + PayloadSocketAcceptorInterceptor interceptor = new PayloadSocketAcceptorInterceptor( + payloadInterceptors()); + RSocketMessageHandler handler = getBean(RSocketMessageHandler.class); + interceptor.setDefaultDataMimeType(handler.getDefaultDataMimeType()); + interceptor.setDefaultMetadataMimeType(handler.getDefaultMetadataMimeType()); + return interceptor; + } + + private List payloadInterceptors() { + List payloadInterceptors = new ArrayList<>(); + + if (this.basicAuthSpec != null) { + payloadInterceptors.add(this.basicAuthSpec.build()); + } + if (this.jwtSpec != null) { + payloadInterceptors.add(this.jwtSpec.build()); + } + payloadInterceptors.add(new AnonymousPayloadInterceptor("anonymousUser")); + + if (this.authorizePayload != null) { + payloadInterceptors.add(this.authorizePayload.build()); + } + return payloadInterceptors; + } + + public class AuthorizePayloadsSpec { + + private PayloadExchangeMatcherReactiveAuthorizationManager.Builder authzBuilder = + PayloadExchangeMatcherReactiveAuthorizationManager.builder(); + + public Access setup() { + return matcher(PayloadExchangeMatchers.setup()); + } + + public Access anyRequest() { + return matcher(PayloadExchangeMatchers.anyExchange()); + } + + protected AuthorizationPayloadInterceptor build() { + return new AuthorizationPayloadInterceptor(this.authzBuilder.build()); + } + + public Access route(String pattern) { + RSocketMessageHandler handler = getBean(RSocketMessageHandler.class); + PayloadExchangeMatcher matcher = new RoutePayloadExchangeMatcher( + handler.getMetadataExtractor(), + handler.getRouteMatcher(), + pattern); + return matcher(matcher); + } + + public Access matcher(PayloadExchangeMatcher matcher) { + return new Access(matcher); + } + + public class Access { + + private final PayloadExchangeMatcher matcher; + + private Access(PayloadExchangeMatcher matcher) { + this.matcher = matcher; + } + + public AuthorizePayloadsSpec authenticated() { + return access(AuthenticatedReactiveAuthorizationManager.authenticated()); + } + + public AuthorizePayloadsSpec hasRole(String role) { + return access(AuthorityReactiveAuthorizationManager.hasRole(role)); + } + + public AuthorizePayloadsSpec permitAll() { + return access((a, ctx) -> Mono + .just(new AuthorizationDecision(true))); + } + + public AuthorizePayloadsSpec access( + ReactiveAuthorizationManager authorization) { + AuthorizePayloadsSpec.this.authzBuilder.add(new PayloadExchangeMatcherEntry<>(this.matcher, authorization)); + return AuthorizePayloadsSpec.this; + } + } + } + + private T getBean(Class beanClass) { + if (this.context == null) { + return null; + } + return this.context.getBean(beanClass); + } + + private T getBeanOrNull(Class beanClass) { + return getBeanOrNull(ResolvableType.forClass(beanClass)); + } + + private T getBeanOrNull(ResolvableType type) { + if (this.context == null) { + return null; + } + String[] names = this.context.getBeanNamesForType(type); + if (names.length == 1) { + return (T) this.context.getBean(names[0]); + } + return null; + } + + protected void setApplicationContext(ApplicationContext applicationContext) + throws BeansException { + this.context = applicationContext; + } +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java new file mode 100644 index 0000000000..fdf9bd31bc --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java @@ -0,0 +1,84 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.rsocket; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; +import org.springframework.security.authentication.ReactiveAuthenticationManager; +import org.springframework.security.authentication.UserDetailsRepositoryReactiveAuthenticationManager; +import org.springframework.security.core.userdetails.ReactiveUserDetailsService; +import org.springframework.security.crypto.password.PasswordEncoder; + +/** + * @author Rob Winch + * @since 5.2 + */ +@Configuration(proxyBeanMethods = false) +class RSocketSecurityConfiguration { + + private static final String BEAN_NAME_PREFIX = "org.springframework.security.config.annotation.rsocket.RSocketSecurityConfiguration."; + private static final String RSOCKET_SECURITY_BEAN_NAME = BEAN_NAME_PREFIX + "rsocketSecurity"; + + private ReactiveAuthenticationManager authenticationManager; + + private ReactiveUserDetailsService reactiveUserDetailsService; + + private PasswordEncoder passwordEncoder; + + @Autowired(required = false) + void setAuthenticationManager( + ReactiveAuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + } + + @Autowired(required = false) + void setUserDetailsService(ReactiveUserDetailsService userDetailsService) { + this.reactiveUserDetailsService = userDetailsService; + } + + @Autowired(required = false) + void setPasswordEncoder(PasswordEncoder passwordEncoder) { + this.passwordEncoder = passwordEncoder; + } + + @Bean(name = RSOCKET_SECURITY_BEAN_NAME) + @Scope("prototype") + public RSocketSecurity rsocketSecurity(ApplicationContext context) { + RSocketSecurity security = new RSocketSecurity() + .authenticationManager(authenticationManager()); + security.setApplicationContext(context); + return security; + } + + private ReactiveAuthenticationManager authenticationManager() { + if (this.authenticationManager != null) { + return this.authenticationManager; + } + if (this.reactiveUserDetailsService != null) { + UserDetailsRepositoryReactiveAuthenticationManager manager = + new UserDetailsRepositoryReactiveAuthenticationManager(this.reactiveUserDetailsService); + if (this.passwordEncoder != null) { + manager.setPasswordEncoder(this.passwordEncoder); + } + return manager; + } + return null; + } +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java new file mode 100644 index 0000000000..501822a65c --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java @@ -0,0 +1,41 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.rsocket; + +import io.rsocket.AbstractRSocket; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.util.ByteBufPayload; +import reactor.core.publisher.Mono; + +public class HelloHandler implements SocketAcceptor { + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + return Mono.just( + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + String data = payload.getDataUtf8(); + payload.release(); + System.out.println("Got " + data); + return Mono.just(ByteBufPayload.create("Hello " + data)); + } + }); + } +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java new file mode 100644 index 0000000000..e48f269b16 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java @@ -0,0 +1,182 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.rsocket; + +import io.rsocket.RSocketFactory; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; +import org.springframework.security.config.Customizer; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor; +import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder; +import org.springframework.security.rsocket.metadata.BearerTokenMetadata; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; +import reactor.core.publisher.Mono; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@ContextConfiguration +@RunWith(SpringRunner.class) +public class JwtITests { + @Autowired + RSocketMessageHandler handler; + + @Autowired + PayloadSocketAcceptorInterceptor interceptor; + + @Autowired + ServerController controller; + + @Autowired + ReactiveJwtDecoder decoder; + + private CloseableChannel server; + + private RSocketRequester requester; + + @Before + public void setup() { + this.server = RSocketFactory.receive() + .frameDecoder(PayloadDecoder.ZERO_COPY) + .addSocketAcceptorPlugin(this.interceptor) + .acceptor(this.handler.responder()) + .transport(TcpServerTransport.create("localhost", 7000)) + .start() + .block(); + } + + @After + public void dispose() { + this.requester.rsocket().dispose(); + this.server.dispose(); + this.controller.payloads.clear(); + } + + @Test + public void routeWhenAuthorized() { + BearerTokenMetadata credentials = + new BearerTokenMetadata("token"); + when(this.decoder.decode(any())).thenReturn(Mono.just(jwt())); + this.requester = requester() + .setupMetadata(credentials.getToken(), BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); + + String hiRob = this.requester.route("secure.retrieve-mono") + .data("rob") + .retrieveMono(String.class) + .block(); + + assertThat(hiRob).isEqualTo("Hi rob"); + } + + private Jwt jwt() { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); + claims.put(IdTokenClaimNames.SUB, "rob"); + claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id")); + Instant issuedAt = Instant.now(); + Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600); + return new Jwt("token", issuedAt, expiresAt, claims, claims); + } + + private RSocketRequester.Builder requester() { + return RSocketRequester.builder() + .rsocketStrategies(this.handler.getRSocketStrategies()); + } + + + @Configuration + @EnableRSocketSecurity + static class Config { + + @Bean + public ServerController controller() { + return new ServerController(); + } + + @Bean + public RSocketMessageHandler messageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; + } + + @Bean + public RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder() + .encoder(new BasicAuthenticationEncoder()) + .build(); + } + + @Bean + PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) { + rsocket + .authorizePayload(authorize -> + authorize + .route("secure.admin.*").authenticated() + .anyRequest().permitAll() + ) + .jwt(Customizer.withDefaults()); + return rsocket.build(); + } + + @Bean + ReactiveJwtDecoder jwtDecoder() { + return mock(ReactiveJwtDecoder.class); + } + } + + @Controller + static class ServerController { + private List payloads = new ArrayList<>(); + + @MessageMapping("**") + String connect(String payload) { + return "Hi " + payload; + } + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java new file mode 100644 index 0000000000..f6883310fe --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java @@ -0,0 +1,246 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.annotation.rsocket; + +import io.rsocket.RSocketFactory; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.rsocket.EnableRSocketSecurity; +import org.springframework.security.config.annotation.rsocket.RSocketSecurity; +import org.springframework.security.core.userdetails.MapReactiveUserDetailsService; +import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor; +import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder; +import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * @author Rob Winch + */ +@ContextConfiguration +@RunWith(SpringRunner.class) +public class RSocketMessageHandlerConnectionITests { + @Autowired + RSocketMessageHandler handler; + + @Autowired + PayloadSocketAcceptorInterceptor interceptor; + + @Autowired + ServerController controller; + + private CloseableChannel server; + + private RSocketRequester requester; + + @Before + public void setup() { + this.server = RSocketFactory.receive() + .frameDecoder(PayloadDecoder.ZERO_COPY) + .addSocketAcceptorPlugin(this.interceptor) + .acceptor(this.handler.responder()) + .transport(TcpServerTransport.create("localhost", 7000)) + .start() + .block(); + } + + @After + public void dispose() { + this.requester.rsocket().dispose(); + this.server.dispose(); + this.controller.payloads.clear(); + } + + @Test + public void routeWhenAuthorized() { + UsernamePasswordMetadata credentials = + new UsernamePasswordMetadata("user", "password"); + this.requester = requester() + .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); + + String hiRob = this.requester.route("secure.retrieve-mono") + .data("rob") + .retrieveMono(String.class) + .block(); + + assertThat(hiRob).isEqualTo("Hi rob"); + } + + @Test + public void routeWhenNotAuthorized() { + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); + this.requester = requester() + .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); + + assertThatCode(() -> this.requester.route("secure.admin.retrieve-mono") + .data("data") + .retrieveMono(String.class) + .block()) + .isInstanceOf(ApplicationErrorException.class); + } + + @Test + public void routeWhenStreamCredentialsAuthorized() { + UsernamePasswordMetadata connectCredentials = new UsernamePasswordMetadata("user", "password"); + this.requester = requester() + .setupMetadata(connectCredentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); + + String hiRob = this.requester.route("secure.admin.retrieve-mono") + .metadata(new UsernamePasswordMetadata("admin", "password"), UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .data("rob") + .retrieveMono(String.class) + .block(); + + assertThat(hiRob).isEqualTo("Hi rob"); + } + + @Test + public void connectWhenNotAuthenticated() { + this.requester = requester() + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); + + assertThatCode(() -> this.requester.route("retrieve-mono") + .data("data") + .retrieveMono(String.class) + .block()) + .isNotNull(); + // FIXME: https://github.com/rsocket/rsocket-java/issues/686 + // .isInstanceOf(RejectedSetupException.class); + } + + @Test + public void connectWhenNotAuthorized() { + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("evil", "password"); + this.requester = requester() + .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); + + assertThatCode(() -> this.requester.route("retrieve-mono") + .data("data") + .retrieveMono(String.class) + .block()) + .isNotNull(); +// FIXME: https://github.com/rsocket/rsocket-java/issues/686 +// .isInstanceOf(RejectedSetupException.class); + } + + private RSocketRequester.Builder requester() { + return RSocketRequester.builder() + .rsocketStrategies(this.handler.getRSocketStrategies()); + } + + + @Configuration + @EnableRSocketSecurity + static class Config { + + @Bean + public ServerController controller() { + return new ServerController(); + } + + @Bean + public RSocketMessageHandler messageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; + } + + @Bean + public RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder() + .encoder(new BasicAuthenticationEncoder()) + .build(); + } + + @Bean + MapReactiveUserDetailsService uds() { + UserDetails admin = User.withDefaultPasswordEncoder() + .username("admin") + .password("password") + .roles("USER", "ADMIN", "SETUP") + .build(); + UserDetails user = User.withDefaultPasswordEncoder() + .username("user") + .password("password") + .roles("USER", "SETUP") + .build(); + + UserDetails evil = User.withDefaultPasswordEncoder() + .username("evil") + .password("password") + .roles("EVIL") + .build(); + return new MapReactiveUserDetailsService(admin, user, evil); + } + + @Bean + PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) { + rsocket + .authorizePayload(authorize -> + authorize + .setup().hasRole("SETUP") + .route("secure.admin.*").hasRole("ADMIN") + .route("secure.**").hasRole("USER") + .anyRequest().permitAll() + ) + .basicAuthentication(Customizer.withDefaults()); + return rsocket.build(); + } + } + + @Controller + static class ServerController { + private List payloads = new ArrayList<>(); + + @MessageMapping("**") + String connect(String payload) { + return "Hi " + payload; + } + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java new file mode 100644 index 0000000000..3cbcc2117d --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java @@ -0,0 +1,312 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.rsocket; + +import io.rsocket.RSocketFactory; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.rsocket.EnableRSocketSecurity; +import org.springframework.security.config.annotation.rsocket.RSocketSecurity; +import org.springframework.security.core.userdetails.MapReactiveUserDetailsService; +import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.rsocket.interceptor.PayloadSocketAcceptorInterceptor; +import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder; +import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; +import org.springframework.stereotype.Controller; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * @author Rob Winch + */ +@ContextConfiguration +@RunWith(SpringRunner.class) +public class RSocketMessageHandlerITests { + @Autowired + RSocketMessageHandler handler; + + @Autowired + PayloadSocketAcceptorInterceptor interceptor; + + @Autowired + ServerController controller; + + private CloseableChannel server; + + private RSocketRequester requester; + + @Before + public void setup() { + this.server = RSocketFactory.receive() + .frameDecoder(PayloadDecoder.ZERO_COPY) + .addSocketAcceptorPlugin(this.interceptor) + .acceptor(this.handler.responder()) + .transport(TcpServerTransport.create("localhost", 7000)) + .start() + .block(); + + this.requester = RSocketRequester.builder() + // .rsocketFactory(factory -> factory.addRequesterPlugin(payloadInterceptor)) + .rsocketStrategies(this.handler.getRSocketStrategies()) + .connectTcp("localhost", 7000) + .block(); + } + + @After + public void dispose() { + this.requester.rsocket().dispose(); + this.server.dispose(); + this.controller.payloads.clear(); + } + + @Test + public void retrieveMonoWhenSecureThenDenied() throws Exception { + String data = "rob"; + assertThatCode(() -> this.requester.route("secure.retrieve-mono") + .data(data) + .retrieveMono(String.class) + .block() + ).isInstanceOf(ApplicationErrorException.class); + assertThat(this.controller.payloads).isEmpty(); + } + + @Test + public void retrieveMonoWhenAuthenticationFailedThenException() throws Exception { + String data = "rob"; + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("invalid", "password"); + assertThatCode(() -> this.requester.route("secure.retrieve-mono") + .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .data(data) + .retrieveMono(String.class) + .block() + ).isInstanceOf(ApplicationErrorException.class); + assertThat(this.controller.payloads).isEmpty(); + } + + @Test + public void retrieveMonoWhenAuthorizedThenGranted() throws Exception { + String data = "rob"; + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password"); + String hiRob = this.requester.route("secure.retrieve-mono") + .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .data(data) + .retrieveMono(String.class) + .block(); + + assertThat(hiRob).isEqualTo("Hi rob"); + assertThat(this.controller.payloads).containsOnly(data); + } + + @Test + public void retrieveMonoWhenPublicThenGranted() throws Exception { + String data = "rob"; + String hiRob = this.requester.route("retrieve-mono") + .data(data) + .retrieveMono(String.class) + .block(); + + assertThat(hiRob).isEqualTo("Hi rob"); + assertThat(this.controller.payloads).containsOnly(data); + } + + @Test + public void retrieveFluxWhenDataFluxAndSecureThenDenied() throws Exception { + Flux data = Flux.just("a", "b", "c"); + assertThatCode(() -> this.requester.route("secure.secure.retrieve-flux") + .data(data, String.class) + .retrieveFlux(String.class) + .collectList() + .block()).isInstanceOf( + ApplicationErrorException.class); + + assertThat(this.controller.payloads).isEmpty(); + } + + @Test + public void retrieveFluxWhenDataFluxAndPublicThenGranted() throws Exception { + Flux data = Flux.just("a", "b", "c"); + List hi = this.requester.route("retrieve-flux") + .data(data, String.class) + .retrieveFlux(String.class) + .collectList() + .block(); + + assertThat(hi).containsOnly("hello a", "hello b", "hello c"); + assertThat(this.controller.payloads).containsOnlyElementsOf(data.collectList().block()); + } + + @Test + public void retrieveFluxWhenDataStringAndSecureThenDenied() throws Exception { + String data = "a"; + assertThatCode(() -> this.requester.route("secure.hello") + .data(data) + .retrieveFlux(String.class) + .collectList() + .block()).isInstanceOf( + ApplicationErrorException.class); + + assertThat(this.controller.payloads).isEmpty(); + } + + @Test + public void retrieveFluxWhenDataStringAndPublicThenGranted() throws Exception { + String data = "a"; + List hi = this.requester.route("retrieve-flux") + .data(data) + .retrieveFlux(String.class) + .collectList() + .block(); + + assertThat(hi).contains("hello a"); + assertThat(this.controller.payloads).containsOnly(data); + } + + @Test + public void sendWhenSecureThenDenied() throws Exception { + String data = "hi"; + this.requester.route("secure.send") + .data(data) + .send() + .block(); + + assertThat(this.controller.payloads).isEmpty(); + } + + @Test + public void sendWhenPublicThenGranted() throws Exception { + String data = "hi"; + this.requester.route("send") + .data(data) + .send() + .block(); + assertThat(this.controller.awaitPayloads()).containsOnly("hi"); + } + + @Configuration + @EnableRSocketSecurity + static class Config { + + @Bean + public ServerController controller() { + return new ServerController(); + } + + @Bean + public RSocketMessageHandler messageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; + } + + @Bean + public RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder() + .encoder(new BasicAuthenticationEncoder()) + .build(); + } + + @Bean + MapReactiveUserDetailsService uds() { + UserDetails rob = User.withDefaultPasswordEncoder() + .username("rob") + .password("password") + .roles("USER", "ADMIN") + .build(); + UserDetails rossen = User.withDefaultPasswordEncoder() + .username("rossen") + .password("password") + .roles("USER") + .build(); + return new MapReactiveUserDetailsService(rob, rossen); + } + + @Bean + PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) { + rsocket + .authorizePayload(authorize -> { + authorize + .route("secure.*").authenticated() + .anyRequest().permitAll(); + }) + .basicAuthentication(Customizer.withDefaults()); + return rsocket.build(); + } + } + + @Controller + static class ServerController { + private List payloads = new ArrayList<>(); + + @MessageMapping({"secure.retrieve-mono", "retrieve-mono"}) + String retrieveMono(String payload) { + add(payload); + return "Hi " + payload; + } + + @MessageMapping({"secure.retrieve-flux", "retrieve-flux"}) + Flux retrieveFlux(Flux payload) { + return payload.doOnNext(this::add) + .map(p -> "hello " + p); + } + + @MessageMapping({"secure.send", "send"}) + Mono send(Flux payload) { + return payload + .doOnNext(this::add) + .then(Mono.fromRunnable(() -> { + doNotifyAll(); + })); + } + + private synchronized void doNotifyAll() { + this.notifyAll(); + } + + private synchronized List awaitPayloads() throws InterruptedException { + this.wait(); + return this.payloads; + } + + private void add(String p) { + this.payloads.add(p); + } + } + +} diff --git a/etc/checkstyle/header.txt b/etc/checkstyle/header.txt index e432c9f5bd..5e5d28b99f 100644 --- a/etc/checkstyle/header.txt +++ b/etc/checkstyle/header.txt @@ -1,5 +1,5 @@ ^\Q/*\E$ -^\Q * Copyright\E (\d{4}\-\d{4} the original author or authors\.|(\d{4}, )*(\d{4}) Acegi Technology Pty Limited)$ +^\Q * Copyright\E (\d{4}(\-\d{4})? the original author or authors\.|(\d{4}, )*(\d{4}) Acegi Technology Pty Limited)$ ^\Q *\E$ ^\Q * Licensed under the Apache License, Version 2.0 (the "License");\E$ ^\Q * you may not use this file except in compliance with the License.\E$ @@ -13,4 +13,4 @@ ^\Q * See the License for the specific language governing permissions and\E$ ^\Q * limitations under the License.\E$ ^\Q */\E$ -^.*$ \ No newline at end of file +^.*$ diff --git a/gradle/dependency-management.gradle b/gradle/dependency-management.gradle index b221730670..76549e80af 100644 --- a/gradle/dependency-management.gradle +++ b/gradle/dependency-management.gradle @@ -1,15 +1,17 @@ if (!project.hasProperty('reactorVersion')) { - ext.reactorVersion = 'Dysprosium-M3' + ext.reactorVersion = 'Dysprosium-RC1' } if (!project.hasProperty('springVersion')) { - ext.springVersion = '5.2.0.RC1' + ext.springVersion = '5.2.0.BUILD-SNAPSHOT' } if (!project.hasProperty('springDataVersion')) { ext.springDataVersion = 'Moore-RC2' } +ext.rsocketVersion = '1.0.0-RC3' + dependencyManagement { imports { mavenBom "io.projectreactor:reactor-bom:${reactorVersion}" @@ -71,6 +73,8 @@ dependencyManagement { dependency 'commons-logging:commons-logging:1.2' dependency 'dom4j:dom4j:1.6.1' dependency 'io.projectreactor.tools:blockhound:1.0.0.M4' + dependency "io.rsocket:rsocket-core:${rsocketVersion}" + dependency "io.rsocket:rsocket-transport-netty:${rsocketVersion}" dependency 'javax.activation:activation:1.1.1' dependency 'javax.annotation:jsr250-api:1.0' dependency 'javax.inject:javax.inject:1' diff --git a/rsocket/spring-security-rsocket.gradle b/rsocket/spring-security-rsocket.gradle new file mode 100644 index 0000000000..fa508c2760 --- /dev/null +++ b/rsocket/spring-security-rsocket.gradle @@ -0,0 +1,9 @@ +apply plugin: 'io.spring.convention.spring-module' + +dependencies { + compile project(':spring-security-core') + compile 'io.rsocket:rsocket-core' + optional project(':spring-security-oauth2-resource-server') + optional 'org.springframework:spring-messaging' + testCompile 'io.projectreactor:reactor-test' +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/ContextPayloadInterceptorChain.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/ContextPayloadInterceptorChain.java new file mode 100644 index 0000000000..eb7d3f9494 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/ContextPayloadInterceptorChain.java @@ -0,0 +1,96 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import java.util.List; +import java.util.ListIterator; + +/** + * A {@link PayloadInterceptorChain} which exposes the Reactor {@link Context} via a member variable. + * This class is not Thread safe, so a new instance must be created for each Thread. + * + * Internally {@code ContextPayloadInterceptorChain} is used to ensure that the Reactor + * {@code Context} is captured so it can be transferred to subscribers outside of this + * {@code Context} in {@code PayloadSocketAcceptor}. + * + * @author Rob Winch + * @since 5.2 + * @see PayloadSocketAcceptor + */ +class ContextPayloadInterceptorChain implements PayloadInterceptorChain { + + private final PayloadInterceptor currentInterceptor; + + private final ContextPayloadInterceptorChain next; + + private Context context; + + ContextPayloadInterceptorChain(List interceptors) { + if (interceptors == null) { + throw new IllegalArgumentException("interceptors cannot be null"); + } + if (interceptors.isEmpty()) { + throw new IllegalArgumentException("interceptors cannot be empty"); + } + ContextPayloadInterceptorChain interceptor = init(interceptors); + this.currentInterceptor = interceptor.currentInterceptor; + this.next = interceptor.next; + } + + private static ContextPayloadInterceptorChain init(List interceptors) { + ContextPayloadInterceptorChain interceptor = new ContextPayloadInterceptorChain(null, null); + ListIterator iterator = interceptors.listIterator(interceptors.size()); + while (iterator.hasPrevious()) { + interceptor = new ContextPayloadInterceptorChain(iterator.previous(), interceptor); + } + return interceptor; + } + + private ContextPayloadInterceptorChain(PayloadInterceptor currentInterceptor, ContextPayloadInterceptorChain next) { + this.currentInterceptor = currentInterceptor; + this.next = next; + } + + public Mono next(PayloadExchange exchange) { + return Mono.defer(() -> + shouldIntercept() ? + this.currentInterceptor.intercept(exchange, this.next) : + Mono.subscriberContext() + .doOnNext(c -> this.context = c) + .then() + ); + } + + Context getContext() { + if (this.next == null) { + return this.context; + } + return this.next.getContext(); + } + + private boolean shouldIntercept() { + return this.currentInterceptor != null && this.next != null; + } + + @Override + public String toString() { + return getClass().getSimpleName() + "[currentInterceptor=" + this.currentInterceptor + "]"; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/DefaultPayloadExchange.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/DefaultPayloadExchange.java new file mode 100644 index 0000000000..9a289cd2e6 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/DefaultPayloadExchange.java @@ -0,0 +1,70 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.Payload; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Default implementation of {@link PayloadExchange} + * + * @author Rob Winch + * @since 5.2 + */ +public class DefaultPayloadExchange implements PayloadExchange { + + private final PayloadExchangeType type; + + private final Payload payload; + + private final MimeType metadataMimeType; + + private final MimeType dataMimeType; + + public DefaultPayloadExchange(PayloadExchangeType type, Payload payload, MimeType metadataMimeType, + MimeType dataMimeType) { + Assert.notNull(type, "type cannot be null"); + Assert.notNull(payload, "payload cannot be null"); + Assert.notNull(metadataMimeType, "metadataMimeType cannot be null"); + Assert.notNull(dataMimeType, "dataMimeType cannot be null"); + this.type = type; + this.payload = payload; + this.metadataMimeType = metadataMimeType; + this.dataMimeType = dataMimeType; + } + + @Override + public PayloadExchangeType getType() { + return this.type; + } + + @Override + public Payload getPayload() { + return this.payload; + } + + @Override + public MimeType getMetadataMimeType() { + return this.metadataMimeType; + } + + @Override + public MimeType getDataMimeType() { + return this.dataMimeType; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchange.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchange.java new file mode 100644 index 0000000000..7cf8ca4dca --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchange.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.Payload; +import org.springframework.util.MimeType; + +/** + * Contract for a Payload interaction. + * + * @author Rob Winch + * @since 5.2 + */ +public interface PayloadExchange { + PayloadExchangeType getType(); + + Payload getPayload(); + + MimeType getDataMimeType(); + + MimeType getMetadataMimeType(); +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchangeType.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchangeType.java new file mode 100644 index 0000000000..455b0e96ef --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadExchangeType.java @@ -0,0 +1,80 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +/** + * The {@link PayloadExchange} type + * + * @author Rob Winch + * @since 5.2 + */ +public enum PayloadExchangeType { + /** + * The Setup. Can + * be used to determine if a Payload is part of the connection + */ + SETUP(false), + + /** + * A Fire and Forget exchange. + */ + FIRE_AND_FORGET(true), + + /** + * A Request + * Response exchange. + */ + REQUEST_RESPONSE(true), + + /** + * A Request Stream + * exchange. This is only represents the request portion. The {@link #PAYLOAD} type + * represents the data that submitted. + */ + REQUEST_STREAM(true), + + /** + * A Request + * Channel exchange. + */ + REQUEST_CHANNEL(true), + + /** + * A Payload exchange. + */ + PAYLOAD(false), + + /** + * A Metadata Push + * exchange. + */ + METADATA_PUSH(true); + + private final boolean isRequest; + + PayloadExchangeType(boolean isRequest) { + this.isRequest = isRequest; + } + + /** + * Determines if this exchange is a type of request (i.e. the initial frame). + * @return true if it is a request, else false + */ + public boolean isRequest() { + return this.isRequest; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptor.java new file mode 100644 index 0000000000..8984ef5417 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptor.java @@ -0,0 +1,38 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import reactor.core.publisher.Mono; + +/** + * Contract for interception-style, chained processing of Payloads that may + * be used to implement cross-cutting, application-agnostic requirements such + * as security, timeouts, and others. + * + * @author Rob Winch + * @since 5.2 + */ +public interface PayloadInterceptor { + /** + * Process the Web request and (optionally) delegate to the next + * {@code PayloadInterceptor} through the given {@link PayloadInterceptorChain}. + * @param exchange the current payload exchange + * @param chain provides a way to delegate to the next interceptor + * @return {@code Mono} to indicate when payload processing is complete + */ + Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain); +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorChain.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorChain.java new file mode 100644 index 0000000000..97c6ef7487 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorChain.java @@ -0,0 +1,34 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import reactor.core.publisher.Mono; + +/** + * Contract to allow a {@link PayloadInterceptor} to delegate to the next in the chain. + * * + * @author Rob Winch + * @since 5.2 + */ +public interface PayloadInterceptorChain { + /** + * Process the payload exchange. + * @param exchange the current server exchange + * @return {@code Mono} to indicate when request processing is complete + */ + Mono next(PayloadExchange exchange); +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocket.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocket.java new file mode 100644 index 0000000000..8a32f0faef --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocket.java @@ -0,0 +1,140 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.ResponderRSocket; +import io.rsocket.util.RSocketProxy; +import org.reactivestreams.Publisher; +import org.springframework.util.MimeType; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import java.util.List; + +/** + * Combines the {@link PayloadInterceptor} with a {@link ResponderRSocket} + * @author Rob Winch + * @since 5.2 + */ +class PayloadInterceptorRSocket extends RSocketProxy implements ResponderRSocket { + private final List interceptors; + + private final MimeType metadataMimeType; + + private final MimeType dataMimeType; + + private final Context context; + + PayloadInterceptorRSocket(RSocket delegate, + List interceptors, MimeType metadataMimeType, + MimeType dataMimeType) { + this(delegate, interceptors, metadataMimeType, dataMimeType, Context.empty()); + } + + PayloadInterceptorRSocket(RSocket delegate, + List interceptors, MimeType metadataMimeType, + MimeType dataMimeType, Context context) { + super(delegate); + this.metadataMimeType = metadataMimeType; + this.dataMimeType = dataMimeType; + if (delegate == null) { + throw new IllegalArgumentException("delegate cannot be null"); + } + if (interceptors == null) { + throw new IllegalArgumentException("interceptors cannot be null"); + } + if (interceptors.isEmpty()) { + throw new IllegalArgumentException("interceptors cannot be empty"); + } + this.interceptors = interceptors; + this.context = context; + } + + @Override + public Mono fireAndForget(Payload payload) { + return intercept(PayloadExchangeType.FIRE_AND_FORGET, payload) + .flatMap(context -> + this.source.fireAndForget(payload) + .subscriberContext(context) + ); + } + + @Override + public Mono requestResponse(Payload payload) { + return intercept(PayloadExchangeType.REQUEST_RESPONSE, payload) + .flatMap(context -> + this.source.requestResponse(payload) + .subscriberContext(context) + ); + } + + @Override + public Flux requestStream(Payload payload) { + return intercept(PayloadExchangeType.REQUEST_STREAM, payload) + .flatMapMany(context -> + this.source.requestStream(payload) + .subscriberContext(context) + ); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .switchOnFirst((signal, innerFlux) -> { + Payload firstPayload = signal.get(); + return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload) + .flatMapMany(context -> + innerFlux + .skip(1) + .flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p)) + .transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads)) + .transform(securedPayloads -> this.source.requestChannel(securedPayloads)) + .subscriberContext(context) + ); + }); + } + + @Override + public Mono metadataPush(Payload payload) { + return intercept(PayloadExchangeType.METADATA_PUSH, payload) + .flatMap(c -> this.source + .metadataPush(payload) + .subscriberContext(c) + ); + } + + private Mono intercept(PayloadExchangeType type, Payload payload) { + return Mono.defer(() -> { + ContextPayloadInterceptorChain chain = new ContextPayloadInterceptorChain(this.interceptors); + DefaultPayloadExchange exchange = new DefaultPayloadExchange(type, payload, + this.metadataMimeType, this.dataMimeType); + return chain.next(exchange) + .then(Mono.fromCallable(() -> chain.getContext())) + .defaultIfEmpty(Context.empty()) + .subscriberContext(this.context); + }); + } + + @Override + public String toString() { + return getClass().getSimpleName() + "[source=" + this.source + ",interceptors=" + + this.interceptors + "]"; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptor.java new file mode 100644 index 0000000000..333a268f98 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptor.java @@ -0,0 +1,99 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.metadata.WellKnownMimeType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import java.util.List; + +/** + * @author Rob Winch + * @since 5.2 + */ +class PayloadSocketAcceptor implements SocketAcceptor { + private final SocketAcceptor delegate; + + private final List interceptors; + + @Nullable + private MimeType defaultDataMimeType; + + private MimeType defaultMetadataMimeType = + MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + PayloadSocketAcceptor(SocketAcceptor delegate, List interceptors) { + Assert.notNull(delegate, "delegate cannot be null"); + if (interceptors == null) { + throw new IllegalArgumentException("interceptors cannot be null"); + } + if (interceptors.isEmpty()) { + throw new IllegalArgumentException("interceptors cannot be empty"); + } + this.delegate = delegate; + this.interceptors = interceptors; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + MimeType dataMimeType = parseMimeType(setup.dataMimeType(), this.defaultDataMimeType); + Assert.notNull(dataMimeType, "No `dataMimeType` in ConnectionSetupPayload and no default value"); + + MimeType metadataMimeType = parseMimeType(setup.metadataMimeType(), this.defaultMetadataMimeType); + Assert.notNull(metadataMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value"); + + // FIXME do we want to make the sendingSocket available in the PayloadExchange + return intercept(setup, dataMimeType, metadataMimeType) + .flatMap(ctx -> this.delegate.accept(setup, sendingSocket) + .map(acceptingSocket -> new PayloadInterceptorRSocket(acceptingSocket, this.interceptors, metadataMimeType, dataMimeType, ctx)) + ); + } + + private Mono intercept(Payload payload, MimeType dataMimeType, MimeType metadataMimeType) { + return Mono.defer(() -> { + ContextPayloadInterceptorChain chain = new ContextPayloadInterceptorChain(this.interceptors); + DefaultPayloadExchange exchange = new DefaultPayloadExchange(PayloadExchangeType.SETUP, payload, + metadataMimeType, dataMimeType); + return chain.next(exchange) + .then(Mono.fromCallable(() -> chain.getContext())) + .defaultIfEmpty(Context.empty()); + }); + } + + private MimeType parseMimeType(String str, MimeType defaultMimeType) { + return StringUtils.hasText(str) ? MimeTypeUtils.parseMimeType(str) : defaultMimeType; + } + + public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) { + this.defaultDataMimeType = defaultDataMimeType; + } + + public void setDefaultMetadataMimeType(MimeType defaultMetadataMimeType) { + Assert.notNull(defaultMetadataMimeType, "defaultMetadataMimeType cannot be null"); + this.defaultMetadataMimeType = defaultMetadataMimeType; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptor.java new file mode 100644 index 0000000000..35fdb36e14 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptor.java @@ -0,0 +1,66 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.SocketAcceptor; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.plugins.SocketAcceptorInterceptor; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; + +import java.util.List; + +/** + * A {@link SocketAcceptorInterceptor} that applies the {@link PayloadInterceptor}s + * + * @author Rob Winch + * @since 5.2 + */ +public class PayloadSocketAcceptorInterceptor implements SocketAcceptorInterceptor { + + private final List interceptors; + + @Nullable + private MimeType defaultDataMimeType; + + private MimeType defaultMetadataMimeType = + MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + public PayloadSocketAcceptorInterceptor(List interceptors) { + this.interceptors = interceptors; + } + + @Override + public SocketAcceptor apply(SocketAcceptor socketAcceptor) { + PayloadSocketAcceptor acceptor = new PayloadSocketAcceptor( + socketAcceptor, this.interceptors); + acceptor.setDefaultDataMimeType(this.defaultDataMimeType); + acceptor.setDefaultMetadataMimeType(this.defaultMetadataMimeType); + return acceptor; + } + + public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) { + this.defaultDataMimeType = defaultDataMimeType; + } + + public void setDefaultMetadataMimeType(MimeType defaultMetadataMimeType) { + Assert.notNull(defaultMetadataMimeType, "defaultMetadataMimeType cannot be null"); + this.defaultMetadataMimeType = defaultMetadataMimeType; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AnonymousPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AnonymousPayloadInterceptor.java new file mode 100644 index 0000000000..97f2866d13 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AnonymousPayloadInterceptor.java @@ -0,0 +1,83 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authentication; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; +import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.interceptor.PayloadInterceptor; + +import java.util.List; + +/** + * If {@link ReactiveSecurityContextHolder} is empty populates an + * {@code AnonymousAuthenticationToken} + * + * @author Rob Winch + * @since 5.2 + */ +public class AnonymousPayloadInterceptor implements PayloadInterceptor { + + private String key; + private Object principal; + private List authorities; + + + /** + * Creates a filter with a principal named "anonymousUser" and the single authority + * "ROLE_ANONYMOUS". + * + * @param key the key to identify tokens created by this filter + */ + public AnonymousPayloadInterceptor(String key) { + this(key, "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + } + + /** + * @param key key the key to identify tokens created by this filter + * @param principal the principal which will be used to represent anonymous users + * @param authorities the authority list for anonymous users + */ + public AnonymousPayloadInterceptor(String key, Object principal, + List authorities) { + Assert.hasLength(key, "key cannot be null or empty"); + Assert.notNull(principal, "Anonymous authentication principal must be set"); + Assert.notNull(authorities, "Anonymous authorities must be set"); + this.key = key; + this.principal = principal; + this.authorities = authorities; + } + + + @Override + public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) { + return ReactiveSecurityContextHolder.getContext() + .switchIfEmpty(Mono.defer(() -> { + AnonymousAuthenticationToken authentication = new AnonymousAuthenticationToken( + this.key, this.principal, this.authorities); + return chain.next(exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .then(Mono.empty()); + })) + .flatMap(securityContext -> chain.next(exchange)); + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AuthenticationPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AuthenticationPayloadInterceptor.java new file mode 100644 index 0000000000..1a988c7e9c --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/AuthenticationPayloadInterceptor.java @@ -0,0 +1,74 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authentication; + +import org.springframework.security.authentication.ReactiveAuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.interceptor.PayloadInterceptor; +import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +/** + * Uses the provided {@code ReactiveAuthenticationManager} to authenticate a Payload. If + * authentication is successful, then the result is added to + * {@link ReactiveSecurityContextHolder}. + * + * @author Rob Winch + * @since 5.2 + */ +public class AuthenticationPayloadInterceptor implements PayloadInterceptor { + + private final ReactiveAuthenticationManager authenticationManager; + + private PayloadExchangeAuthenticationConverter authenticationConverter = + new BasicAuthenticationPayloadExchangeConverter(); + + /** + * Creates a new instance + * @param authenticationManager the manager to use. Cannot be null + */ + public AuthenticationPayloadInterceptor(ReactiveAuthenticationManager authenticationManager) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + this.authenticationManager = authenticationManager; + } + + /** + * Sets the convert to be used + * @param authenticationConverter + */ + public void setAuthenticationConverter( + PayloadExchangeAuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) { + return this.authenticationConverter.convert(exchange) + .switchIfEmpty(chain.next(exchange).then(Mono.empty())) + .flatMap(a -> this.authenticationManager.authenticate(a)) + .flatMap(a -> onAuthenticationSuccess(chain.next(exchange), a)); + } + + private Mono onAuthenticationSuccess(Mono payload, Authentication authentication) { + return payload + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)); + } + +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BasicAuthenticationPayloadExchangeConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BasicAuthenticationPayloadExchangeConverter.java new file mode 100644 index 0000000000..c4bce298aa --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BasicAuthenticationPayloadExchangeConverter.java @@ -0,0 +1,60 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authentication; + +import io.rsocket.metadata.WellKnownMimeType; +import org.springframework.messaging.rsocket.DefaultMetadataExtractor; +import org.springframework.messaging.rsocket.MetadataExtractor; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.metadata.BasicAuthenticationDecoder; +import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Mono; + +/** + * Converts from the {@link PayloadExchange} to a + * {@link UsernamePasswordAuthenticationToken} by extracting + * {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE} from the metadata. + * + * @author Rob Winch + * @since 5.2 + */ +public class BasicAuthenticationPayloadExchangeConverter implements PayloadExchangeAuthenticationConverter { + + private MimeType metadataMimetype = MimeTypeUtils.parseMimeType( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + private MetadataExtractor metadataExtractor = createDefaultExtractor(); + + @Override + public Mono convert(PayloadExchange exchange) { + return Mono.fromCallable(() -> this.metadataExtractor + .extract(exchange.getPayload(), this.metadataMimetype)) + .flatMap(metadata -> Mono.justOrEmpty(metadata.get(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE.toString()))) + .cast(UsernamePasswordMetadata.class) + .map(credentials -> new UsernamePasswordAuthenticationToken(credentials.getUsername(), credentials.getPassword())); + } + + private static MetadataExtractor createDefaultExtractor() { + DefaultMetadataExtractor result = new DefaultMetadataExtractor(new BasicAuthenticationDecoder()); + result.metadataToExtract(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE, UsernamePasswordMetadata.class, (String) null); + return result; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BearerPayloadExchangeConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BearerPayloadExchangeConverter.java new file mode 100644 index 0000000000..cc9db71dfc --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/BearerPayloadExchangeConverter.java @@ -0,0 +1,54 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authentication; + +import io.netty.buffer.ByteBuf; +import io.rsocket.metadata.CompositeMetadata; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.metadata.BearerTokenMetadata; +import reactor.core.publisher.Mono; + +import java.nio.charset.StandardCharsets; + +/** + * Converts from the {@link PayloadExchange} to a + * {@link BearerTokenAuthenticationToken} by extracting + * {@link BearerTokenMetadata#BEARER_AUTHENTICATION_MIME_TYPE} from the metadata. + * @author Rob Winch + * @since 5.2 + */ +public class BearerPayloadExchangeConverter implements PayloadExchangeAuthenticationConverter { + + private static final String BEARER_MIME_TYPE_VALUE = + BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE.toString(); + + @Override + public Mono convert(PayloadExchange exchange) { + ByteBuf metadata = exchange.getPayload().metadata(); + CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + if (BEARER_MIME_TYPE_VALUE.equals(entry.getMimeType())) { + ByteBuf content = entry.getContent(); + String token = content.toString(StandardCharsets.UTF_8); + return Mono.just(new BearerTokenAuthenticationToken(token)); + } + } + return Mono.empty(); + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/PayloadExchangeAuthenticationConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/PayloadExchangeAuthenticationConverter.java new file mode 100644 index 0000000000..2713aeb06a --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authentication/PayloadExchangeAuthenticationConverter.java @@ -0,0 +1,30 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authentication; + +import org.springframework.security.core.Authentication; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import reactor.core.publisher.Mono; + +/** + * Converts from a {@link PayloadExchange} to an {@link Authentication} + * @author Rob Winch + * @since 5.2 + */ +public interface PayloadExchangeAuthenticationConverter { + Mono convert(PayloadExchange exchange); +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/AuthorizationPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/AuthorizationPayloadInterceptor.java new file mode 100644 index 0000000000..1fc8011594 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/AuthorizationPayloadInterceptor.java @@ -0,0 +1,53 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authorization; + +import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; +import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; +import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.interceptor.PayloadInterceptor; + +/** + * Provides authorization of the {@link PayloadExchange}. + * + * @author Rob Winch + * @since 5.2 + */ +public class AuthorizationPayloadInterceptor implements PayloadInterceptor { + private final ReactiveAuthorizationManager authorizationManager; + + public AuthorizationPayloadInterceptor( + ReactiveAuthorizationManager authorizationManager) { + Assert.notNull(authorizationManager, "authorizationManager cannot be null"); + this.authorizationManager = authorizationManager; + } + + @Override + public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) { + return ReactiveSecurityContextHolder.getContext() + .filter(c -> c.getAuthentication() != null) + .map(SecurityContext::getAuthentication) + .switchIfEmpty(Mono.error(() -> new AuthenticationCredentialsNotFoundException("An Authentication (possibly AnonymousAuthenticationToken) is required."))) + .as(authentication -> this.authorizationManager.verify(authentication, exchange)) + .then(chain.next(exchange)); + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java new file mode 100644 index 0000000000..7fb7096850 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java @@ -0,0 +1,82 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authorization; + +import org.springframework.security.authorization.AuthorizationDecision; +import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.core.Authentication; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.util.PayloadExchangeAuthorizationContext; +import org.springframework.security.rsocket.util.PayloadExchangeMatcher; +import org.springframework.security.rsocket.util.PayloadExchangeMatcherEntry; + +import java.util.ArrayList; +import java.util.List; + +/** + * Maps a @{code List} of {@link PayloadExchangeMatcher} instances to + * @{code ReactiveAuthorizationManager} instances. + * + * @author Rob Winch + * @since 5.2 + */ +public class PayloadExchangeMatcherReactiveAuthorizationManager implements ReactiveAuthorizationManager { + private final List>> mappings; + + private PayloadExchangeMatcherReactiveAuthorizationManager(List>> mappings) { + Assert.notEmpty(mappings, "mappings cannot be null"); + this.mappings = mappings; + } + + @Override + public Mono check(Mono authentication, PayloadExchange exchange) { + return Flux.fromIterable(this.mappings) + .concatMap(mapping -> mapping.getMatcher().matches(exchange) + .filter(PayloadExchangeMatcher.MatchResult::isMatch) + .map(r -> r.getVariables()) + .flatMap(variables -> mapping.getEntry() + .check(authentication, new PayloadExchangeAuthorizationContext(exchange, variables)) + ) + ) + .next() + .switchIfEmpty(Mono.fromCallable(() -> new AuthorizationDecision(false))); + } + + public static PayloadExchangeMatcherReactiveAuthorizationManager.Builder builder() { + return new PayloadExchangeMatcherReactiveAuthorizationManager.Builder(); + } + + public static class Builder { + private final List>> mappings = new ArrayList<>(); + + private Builder() { + } + + public PayloadExchangeMatcherReactiveAuthorizationManager.Builder add( + PayloadExchangeMatcherEntry> entry) { + this.mappings.add(entry); + return this; + } + + public PayloadExchangeMatcherReactiveAuthorizationManager build() { + return new PayloadExchangeMatcherReactiveAuthorizationManager(this.mappings); + } + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java new file mode 100644 index 0000000000..5085e5a833 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java @@ -0,0 +1,76 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.metadata; + +import org.reactivestreams.Publisher; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDecoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.util.MimeType; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.Map; + +/** + * Decodes {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE} + * + * @author Rob Winch + * @since 5.2 + */ +public class BasicAuthenticationDecoder extends AbstractDecoder { + public BasicAuthenticationDecoder() { + super(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE); + } + + @Override + public Flux decode(Publisher input, + ResolvableType elementType, MimeType mimeType, Map hints) { + return Flux.from(input) + .map(DataBuffer::asByteBuffer) + .map(byteBuffer -> { + byte[] sizeBytes = new byte[4]; + byteBuffer.get(sizeBytes); + + int usernameSize = 4; + byte[] usernameBytes = new byte[usernameSize]; + byteBuffer.get(usernameBytes); + byte[] passwordBytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(passwordBytes); + String username = new String(usernameBytes); + String password = new String(passwordBytes); + return new UsernamePasswordMetadata(username, password); + }); + } + + @Override + public Mono decodeToMono(Publisher input, + ResolvableType elementType, MimeType mimeType, Map hints) { + return Mono.from(input) + .map(DataBuffer::asByteBuffer) + .map(byteBuffer -> { + int usernameSize = byteBuffer.getInt(); + byte[] usernameBytes = new byte[usernameSize]; + byteBuffer.get(usernameBytes); + byte[] passwordBytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(passwordBytes); + String username = new String(usernameBytes); + String password = new String(passwordBytes); + return new UsernamePasswordMetadata(username, password); + }); + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java new file mode 100644 index 0000000000..9d088f5a2a --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java @@ -0,0 +1,76 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.metadata; + +import org.reactivestreams.Publisher; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractEncoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.util.MimeType; +import reactor.core.publisher.Flux; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** + * Encodes {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE} + * + * @author Rob Winch + * @since 5.2 + */ +public class BasicAuthenticationEncoder extends + AbstractEncoder { + + public BasicAuthenticationEncoder() { + super(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE); + } + + @Override + public Flux encode( + Publisher inputStream, + DataBufferFactory bufferFactory, ResolvableType elementType, + MimeType mimeType, Map hints) { + return Flux.from(inputStream).map(credentials -> + encodeValue(credentials, bufferFactory, elementType, mimeType, hints)); + } + + @Override + public DataBuffer encodeValue(UsernamePasswordMetadata credentials, + DataBufferFactory bufferFactory, ResolvableType valueType, MimeType mimeType, + Map hints) { + String username = credentials.getUsername(); + String password = credentials.getPassword(); + byte[] usernameBytes = username.getBytes(StandardCharsets.UTF_8); + byte[] usernameBytesLengthBytes = ByteBuffer.allocate(4).putInt(usernameBytes.length).array(); + DataBuffer metadata = bufferFactory.allocateBuffer(); + boolean release = true; + try { + metadata.write(usernameBytesLengthBytes); + metadata.write(usernameBytes); + metadata.write(password.getBytes(StandardCharsets.UTF_8)); + release = false; + return metadata; + } finally { + if (release) { + DataBufferUtils.release(metadata); + } + } + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java new file mode 100644 index 0000000000..e252fa21f3 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java @@ -0,0 +1,47 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.springframework.security.rsocket.metadata; + +import org.springframework.http.MediaType; +import org.springframework.util.MimeType; + +/** + * Represents a bearer token that has been encoded into a + * {@link Payload#metadata()}. + * + * @author Rob Winch + * @since 5.2 + */ +public class BearerTokenMetadata { + /** + * Represents a bearer token which is encoded as a String. + * + * See rsocket/rsocket#272 + */ + public static final MimeType BEARER_AUTHENTICATION_MIME_TYPE = new MediaType("message", "x.rsocket.authentication.bearer.v0"); + + private final String token; + + public BearerTokenMetadata(String token) { + this.token = token; + } + + public String getToken() { + return this.token; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java new file mode 100644 index 0000000000..e99e23aa40 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java @@ -0,0 +1,55 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.metadata; + +import io.rsocket.Payload; +import org.springframework.http.MediaType; +import org.springframework.util.MimeType; + +/** + * Represents a username and password that have been encoded into a + * {@link Payload#metadata()}. + * + * @author Rob Winch + * @since 5.2 + */ +public final class UsernamePasswordMetadata { + /** + * Represents a username password which is encoded as + * {@code ${username-bytes-length}${username-bytes}${password-bytes}}. + * + * See rsocket/rsocket#272 + */ + public static final MimeType BASIC_AUTHENTICATION_MIME_TYPE = new MediaType("message", "x.rsocket.authentication.basic.v0"); + + private final String username; + + private final String password; + + public UsernamePasswordMetadata(String username, String password) { + this.username = username; + this.password = password; + } + + public String getUsername() { + return this.username; + } + + public String getPassword() { + return this.password; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeAuthorizationContext.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeAuthorizationContext.java new file mode 100644 index 0000000000..ac01e07f73 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeAuthorizationContext.java @@ -0,0 +1,48 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.util; + +import org.springframework.security.rsocket.interceptor.PayloadExchange; + +import java.util.Collections; +import java.util.Map; + +/** + * @author Rob Winch + * @since 5.2 + */ +public class PayloadExchangeAuthorizationContext { + private final PayloadExchange exchange; + private final Map variables; + + public PayloadExchangeAuthorizationContext(PayloadExchange exchange) { + this(exchange, Collections.emptyMap()); + } + + public PayloadExchangeAuthorizationContext(PayloadExchange exchange, Map variables) { + this.exchange = exchange; + this.variables = variables; + } + + public PayloadExchange getExchange() { + return this.exchange; + } + + public Map getVariables() { + return Collections.unmodifiableMap(this.variables); + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcher.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcher.java new file mode 100644 index 0000000000..d5a368a1f8 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcher.java @@ -0,0 +1,89 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.rsocket.util; + +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * An interface for determining if a {@link PayloadExchangeMatcher} matches. + * @author Rob Winch + * @since 5.2 + */ +public interface PayloadExchangeMatcher { + + /** + * Determines if a request matches or not + * @param exchange + * @return + */ + Mono matches(PayloadExchange exchange); + + /** + * The result of matching + */ + class MatchResult { + private final boolean match; + private final Map variables; + + private MatchResult(boolean match, Map variables) { + this.match = match; + this.variables = variables; + } + + public boolean isMatch() { + return match; + } + + /** + * Gets potential variables and their values + * @return + */ + public Map getVariables() { + return variables; + } + + /** + * Creates an instance of {@link MatchResult} that is a match with no variables + * @return + */ + public static Mono match() { + return match(Collections.emptyMap()); + } + + /** + * + * Creates an instance of {@link MatchResult} that is a match with the specified variables + * @param variables + * @return + */ + public static Mono match(Map variables) { + return Mono.just(new MatchResult(true, variables == null ? null : new HashMap(variables))); + } + + /** + * Creates an instance of {@link MatchResult} that is not a match. + * @return + */ + public static Mono notMatch() { + return Mono.just(new MatchResult(false, Collections.emptyMap())); + } + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcherEntry.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcherEntry.java new file mode 100644 index 0000000000..691033c417 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatcherEntry.java @@ -0,0 +1,38 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.util; + +/** + * @author Rob Winch + */ +public class PayloadExchangeMatcherEntry { + private final PayloadExchangeMatcher matcher; + private final T entry; + + public PayloadExchangeMatcherEntry(PayloadExchangeMatcher matcher, T entry) { + this.matcher = matcher; + this.entry = entry; + } + + public PayloadExchangeMatcher getMatcher() { + return this.matcher; + } + + public T getEntry() { + return this.entry; + } +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatchers.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatchers.java new file mode 100644 index 0000000000..9202949ac3 --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/PayloadExchangeMatchers.java @@ -0,0 +1,57 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.util; + +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.interceptor.PayloadExchangeType; +import reactor.core.publisher.Mono; + +/** + * @author Rob Winch + */ +public abstract class PayloadExchangeMatchers { + + public static PayloadExchangeMatcher setup() { + return new PayloadExchangeMatcher() { + public Mono matches(PayloadExchange exchange) { + return PayloadExchangeType.SETUP.equals(exchange.getType()) ? + MatchResult.match() : + MatchResult.notMatch(); + } + }; + } + + public static PayloadExchangeMatcher anyRequest() { + return new PayloadExchangeMatcher() { + public Mono matches(PayloadExchange exchange) { + return exchange.getType().isRequest() ? + MatchResult.match() : + MatchResult.notMatch(); + } + }; + } + + public static PayloadExchangeMatcher anyExchange() { + return new PayloadExchangeMatcher() { + public Mono matches(PayloadExchange exchange) { + return MatchResult.match(); + } + }; + } + + private PayloadExchangeMatchers() {} +} diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcher.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcher.java new file mode 100644 index 0000000000..0b711d212b --- /dev/null +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcher.java @@ -0,0 +1,61 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.util; + +import org.springframework.messaging.rsocket.MetadataExtractor; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.util.Assert; +import org.springframework.util.RouteMatcher; +import reactor.core.publisher.Mono; + +import java.util.Map; +import java.util.Optional; + +/** + * FIXME: Pay attention to the package this goes into. It requires spring-messaging for + * the MetadataExtractor. + * + * @author Rob Winch + * @since 5.2 + */ +public class RoutePayloadExchangeMatcher implements PayloadExchangeMatcher { + + private final String pattern; + + private final MetadataExtractor metadataExtractor; + + private final RouteMatcher routeMatcher; + + public RoutePayloadExchangeMatcher(MetadataExtractor metadataExtractor, + RouteMatcher routeMatcher, String pattern) { + Assert.notNull(pattern, "pattern cannot be null"); + this.metadataExtractor = metadataExtractor; + this.routeMatcher = routeMatcher; + this.pattern = pattern; + } + + @Override + public Mono matches(PayloadExchange exchange) { + Map metadata = this.metadataExtractor + .extract(exchange.getPayload(), exchange.getMetadataMimeType()); + return Optional.ofNullable((String) metadata.get(MetadataExtractor.ROUTE_KEY)) + .map(routeValue -> this.routeMatcher.parseRoute(routeValue)) + .map(route -> this.routeMatcher.matchAndExtract(this.pattern, route)) + .map(v -> MatchResult.match(v)) + .orElse(MatchResult.notMatch()); + } +} diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java new file mode 100644 index 0000000000..86a00c93d0 --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java @@ -0,0 +1,108 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.interceptor.authentication.AnonymousPayloadInterceptor; + +import java.util.List; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class AnonymousPayloadInterceptorTests { + @Mock + private PayloadExchange exchange; + + private AnonymousPayloadInterceptor interceptor; + + @Before + public void setup() { + this.interceptor = new AnonymousPayloadInterceptor("key"); + } + + @Test + public void constructorKeyWhenKeyNullThenException() { + String key = null; + assertThatCode(() -> new AnonymousPayloadInterceptor(key)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorKeyPrincipalAuthoritiesWhenKeyNullThenException() { + String key = null; + assertThatCode(() -> new AnonymousPayloadInterceptor(key, "principal", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorKeyPrincipalAuthoritiesWhenPrincipalNullThenException() { + Object principal = null; + assertThatCode(() -> new AnonymousPayloadInterceptor("key", principal, + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorKeyPrincipalAuthoritiesWhenAuthoritiesNullThenException() { + List authorities = null; + assertThatCode(() -> new AnonymousPayloadInterceptor("key", "principal", + authorities)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void interceptWhenNoAuthenticationThenAnonymousAuthentication() { + AuthenticationPayloadInterceptorChain chain = new AuthenticationPayloadInterceptorChain(); + + this.interceptor.intercept(this.exchange, chain).block(); + + Authentication authentication = chain.getAuthentication(); + + assertThat(authentication).isInstanceOf(AnonymousAuthenticationToken.class); + } + + @Test + public void interceptWhenAuthenticationThenOriginalAuthentication() { + AuthenticationPayloadInterceptorChain chain = new AuthenticationPayloadInterceptorChain(); + TestingAuthenticationToken expected = + new TestingAuthenticationToken("test", "password"); + + this.interceptor.intercept(this.exchange, chain) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expected)) + .block(); + + Authentication authentication = chain.getAuthentication(); + + assertThat(authentication).isEqualTo(expected); + } +} \ No newline at end of file diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java new file mode 100644 index 0000000000..2f5480ab66 --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java @@ -0,0 +1,45 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.authentication; + +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import reactor.core.publisher.Mono; +import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain; +import org.springframework.security.rsocket.interceptor.PayloadExchange; + +/** + * @author Rob Winch + */ +class AuthenticationPayloadInterceptorChain implements PayloadInterceptorChain { + private Authentication authentication; + + @Override + public Mono next(PayloadExchange exchange) { + return ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication) + .doOnNext(a -> this.setAuthentication(a)).then(); + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public void setAuthentication(Authentication authentication) { + this.authentication = authentication; + } +} diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java new file mode 100644 index 0000000000..e1ed44301c --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.authentication; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.CompositeMetadataFlyweight; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.DefaultPayload; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.MediaType; +import org.springframework.security.authentication.ReactiveAuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.rsocket.interceptor.PayloadExchangeType; +import org.springframework.security.rsocket.interceptor.authentication.AuthenticationPayloadInterceptor; +import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder; +import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; +import org.springframework.security.rsocket.interceptor.DefaultPayloadExchange; +import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain; +import org.springframework.security.rsocket.interceptor.PayloadExchange; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class AuthenticationPayloadInterceptorTests { + static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + @Mock + ReactiveAuthenticationManager authenticationManager; + + @Captor + ArgumentCaptor authenticationArg; + + @Test + public void constructorWhenAuthenticationManagerNullThenException() { + assertThatCode(() -> new AuthenticationPayloadInterceptor(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void interceptWhenBasicCredentialsThenAuthenticates() { + AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor( + this.authenticationManager); + PayloadExchange exchange = createExchange(); + TestingAuthenticationToken expectedAuthentication = + new TestingAuthenticationToken("user", "password"); + when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just( + expectedAuthentication)); + + AuthenticationPayloadInterceptorChain authenticationPayloadChain = new AuthenticationPayloadInterceptorChain(); + interceptor.intercept(exchange, authenticationPayloadChain) + .block(); + + Authentication authentication = authenticationPayloadChain.getAuthentication(); + + verify(this.authenticationManager).authenticate(this.authenticationArg.capture()); + assertThat(this.authenticationArg.getValue()).isEqualToComparingFieldByField(new UsernamePasswordAuthenticationToken("user", "password")); + assertThat(authentication).isEqualTo(expectedAuthentication); + } + + @Test + public void interceptWhenAuthenticationSuccessThenChainSubscribedOnce() { + AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor( + this.authenticationManager); + + PayloadExchange exchange = createExchange(); + TestingAuthenticationToken expectedAuthentication = + new TestingAuthenticationToken("user", "password"); + when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just( + expectedAuthentication)); + + PublisherProbe voidResult = PublisherProbe.empty(); + PayloadInterceptorChain chain = mock(PayloadInterceptorChain.class); + when(chain.next(any())).thenReturn(voidResult.mono()); + + + StepVerifier.create(interceptor.intercept(exchange, chain)) + .then(() -> assertThat(voidResult.subscribeCount()).isEqualTo(1)) + .verifyComplete(); + } + + private Payload createRequestPayload() { + + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); + BasicAuthenticationEncoder encoder = new BasicAuthenticationEncoder(); + DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); + ResolvableType elementType = ResolvableType + .forClass(UsernamePasswordMetadata.class); + MimeType mimeType = UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE; + Map hints = null; + DataBuffer dataBuffer = encoder.encodeValue(credentials, factory, + elementType, mimeType, hints); + + ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + CompositeByteBuf metadata = allocator.compositeBuffer(); + CompositeMetadataFlyweight.encodeAndAddMetadata( + metadata, allocator, mimeType.toString(), NettyDataBufferFactory.toByteBuf(dataBuffer)); + + return DefaultPayload.create(allocator.buffer(), + metadata); + } + + private PayloadExchange createExchange() { + return new DefaultPayloadExchange(PayloadExchangeType.REQUEST_RESPONSE, createRequestPayload(), COMPOSITE_METADATA, + MediaType.APPLICATION_JSON); + } + +} diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java new file mode 100644 index 0000000000..c2c5098afa --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.authorization; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.rsocket.interceptor.authorization.AuthorizationPayloadInterceptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; +import reactor.util.context.Context; +import org.springframework.security.rsocket.interceptor.PayloadInterceptorChain; +import org.springframework.security.rsocket.interceptor.PayloadExchange; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; +import static org.springframework.security.authorization.AuthenticatedReactiveAuthorizationManager.authenticated; +import static org.springframework.security.authorization.AuthorityReactiveAuthorizationManager.hasRole; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class AuthorizationPayloadInterceptorTests { + @Mock + private ReactiveAuthorizationManager authorizationManager; + + @Mock + private PayloadExchange exchange; + + @Mock + private PayloadInterceptorChain chain; + + private PublisherProbe managerResult = PublisherProbe.empty(); + + private PublisherProbe chainResult = PublisherProbe.empty(); + + @Test + public void interceptWhenAuthenticationEmptyAndSubscribedThenException() { + when(this.chain.next(any())).thenReturn(this.chainResult.mono()); + + AuthorizationPayloadInterceptor interceptor = + new AuthorizationPayloadInterceptor(authenticated()); + + StepVerifier.create(interceptor.intercept(this.exchange, this.chain)) + .then(() -> this.chainResult.assertWasNotSubscribed()) + .verifyError(AuthenticationCredentialsNotFoundException.class); + } + + @Test + public void interceptWhenAuthenticationNotSubscribedAndEmptyThenCompletes() { + when(this.chain.next(any())).thenReturn(this.chainResult.mono()); + when(this.authorizationManager.verify(any(), any())) + .thenReturn(this.managerResult.mono()); + + AuthorizationPayloadInterceptor interceptor = + new AuthorizationPayloadInterceptor(this.authorizationManager); + + StepVerifier.create(interceptor.intercept(this.exchange, this.chain)) + .then(() -> this.chainResult.assertWasSubscribed()) + .verifyComplete(); + } + + @Test + public void interceptWhenNotAuthorizedThenException() { + when(this.chain.next(any())).thenReturn(this.chainResult.mono()); + + AuthorizationPayloadInterceptor interceptor = + new AuthorizationPayloadInterceptor(hasRole("USER")); + Context userContext = ReactiveSecurityContextHolder + .withAuthentication(new TestingAuthenticationToken("user", "password")); + + Mono intercept = interceptor.intercept(this.exchange, this.chain) + .subscriberContext(userContext); + + StepVerifier.create(intercept) + .then(() -> this.chainResult.assertWasNotSubscribed()) + .verifyError(AccessDeniedException.class); + } + + @Test + public void interceptWhenAuthorizedThenContinues() { + when(this.chain.next(any())).thenReturn(this.chainResult.mono()); + + AuthorizationPayloadInterceptor interceptor = + new AuthorizationPayloadInterceptor(authenticated()); + Context userContext = ReactiveSecurityContextHolder + .withAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); + + Mono intercept = interceptor.intercept(this.exchange, this.chain) + .subscriberContext(userContext); + + StepVerifier.create(intercept) + .then(() -> this.chainResult.assertWasSubscribed()) + .verifyComplete(); + } +} \ No newline at end of file diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocketTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocketTests.java new file mode 100644 index 0000000000..6dc06fc168 --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadInterceptorRSocketTests.java @@ -0,0 +1,509 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; +import org.reactivestreams.Publisher; +import org.springframework.http.MediaType; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; +import reactor.test.publisher.TestPublisher; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class PayloadInterceptorRSocketTests { + + static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + @Mock + RSocket delegate; + + @Mock + PayloadInterceptor interceptor; + + @Mock + PayloadInterceptor interceptor2; + + @Mock + Payload payload; + + @Captor + private ArgumentCaptor exchange; + + PublisherProbe voidResult = PublisherProbe.empty(); + + TestPublisher payloadResult = TestPublisher.createCold(); + + private MimeType metadataMimeType = COMPOSITE_METADATA; + + private MimeType dataMimeType = MediaType.APPLICATION_JSON; + + @Test + public void constructorWhenNullDelegateThenException() { + this.delegate = null; + List interceptors = Arrays.asList(this.interceptor); + assertThatCode(() -> { + new PayloadInterceptorRSocket(this.delegate, interceptors, + metadataMimeType, dataMimeType); + }) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenNullInterceptorsThenException() { + List interceptors = null; + assertThatCode(() -> new PayloadInterceptorRSocket(this.delegate, interceptors, + metadataMimeType, dataMimeType)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenEmptyInterceptorsThenException() { + List interceptors = Collections.emptyList(); + assertThatCode(() -> new PayloadInterceptorRSocket(this.delegate, interceptors, + metadataMimeType, dataMimeType)) + .isInstanceOf(IllegalArgumentException.class); + } + + // single interceptor + + @Test + public void fireAndForgetWhenInterceptorCompletesThenDelegateSubscribed() { + when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); + when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono()); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.fireAndForget(this.payload)) + .then(() -> this.voidResult.assertWasSubscribed()) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + } + + @Test + public void fireAndForgetWhenInterceptorErrorsThenDelegateNotSubscribed() { + RuntimeException expected = new RuntimeException("Oops"); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.fireAndForget(this.payload)) + .then(() -> this.voidResult.assertWasNotSubscribed()) + .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + } + + @Test + public void fireAndForgetWhenSecurityContextThenDelegateContext() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); + when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); + when(this.delegate.fireAndForget(any())).thenReturn(Mono.empty()); + + RSocket assertAuthentication = new RSocketProxy(this.delegate) { + @Override + public Mono fireAndForget(Payload payload) { + return assertAuthentication(authentication) + .flatMap(a -> super.fireAndForget(payload)); + } + }; + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + interceptor.fireAndForget(this.payload).block(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.delegate).fireAndForget(this.payload); + } + + @Test + public void requestResponseWhenInterceptorCompletesThenDelegateSubscribed() { + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); + when(this.delegate.requestResponse(any())).thenReturn(this.payloadResult.mono()); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.requestResponse(this.payload)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.delegate).requestResponse(this.payload); + } + + @Test + public void requestResponseWhenInterceptorErrorsThenDelegateNotInvoked() { + RuntimeException expected = new RuntimeException("Oops"); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + assertThatCode(() -> interceptor.requestResponse(this.payload).block()).isEqualTo(expected); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verifyZeroInteractions(this.delegate); + } + + @Test + public void requestResponseWhenSecurityContextThenDelegateContext() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); + when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); + when(this.delegate.requestResponse(any())).thenReturn(this.payloadResult.mono()); + + RSocket assertAuthentication = new RSocketProxy(this.delegate) { + @Override + public Mono requestResponse(Payload payload) { + return assertAuthentication(authentication) + .flatMap(a -> super.requestResponse(payload)); + } + }; + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.requestResponse(this.payload)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.delegate).requestResponse(this.payload); + } + + @Test + public void requestStreamWhenInterceptorCompletesThenDelegateSubscribed() { + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); + when(this.delegate.requestStream(any())).thenReturn(this.payloadResult.flux()); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.requestStream(this.payload)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + } + + @Test + public void requestStreamWhenInterceptorErrorsThenDelegateNotSubscribed() { + RuntimeException expected = new RuntimeException("Oops"); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.requestStream(this.payload)) + .then(() -> this.payloadResult.assertNoSubscribers()) + .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + } + + @Test + public void requestStreamWhenSecurityContextThenDelegateContext() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); + when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); + when(this.delegate.requestStream(any())).thenReturn(this.payloadResult.flux()); + + RSocket assertAuthentication = new RSocketProxy(this.delegate) { + @Override + public Flux requestStream(Payload payload) { + return assertAuthentication(authentication) + .flatMapMany(a -> super.requestStream(payload)); + } + }; + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.requestStream(this.payload)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.delegate).requestStream(this.payload); + } + + @Test + public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() { + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); + when(this.delegate.requestChannel(any())).thenReturn(this.payloadResult.flux()); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.requestChannel(Flux.just(this.payload))) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.delegate).requestChannel(any()); + } + + @Test + public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() { + RuntimeException expected = new RuntimeException("Oops"); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + + StepVerifier.create(interceptor.requestChannel(Flux.just(this.payload))) + .then(() -> this.payloadResult.assertNoSubscribers()) + .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + } + + @Test + public void requestChannelWhenSecurityContextThenDelegateContext() { + Mono payload = Mono.just(this.payload); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); + when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); + when(this.delegate.requestChannel(any())).thenReturn(this.payloadResult.flux()); + + RSocket assertAuthentication = new RSocketProxy(this.delegate) { + @Override + public Flux requestChannel(Publisher payload) { + return assertAuthentication(authentication) + .flatMapMany(a -> super.requestChannel(payload)); + } + }; + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.requestChannel(payload)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.delegate).requestChannel(any()); + } + + @Test + public void metadataPushWhenInterceptorCompletesThenDelegateSubscribed() { + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); + when(this.delegate.metadataPush(any())).thenReturn(this.voidResult.mono()); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.metadataPush(this.payload)) + .then(() -> this.voidResult.assertWasSubscribed()) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + } + + @Test + public void metadataPushWhenInterceptorErrorsThenDelegateNotSubscribed() { + RuntimeException expected = new RuntimeException("Oops"); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.metadataPush(this.payload)) + .then(() -> this.voidResult.assertWasNotSubscribed()) + .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + } + + @Test + public void metadataPushWhenSecurityContextThenDelegateContext() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); + when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); + when(this.delegate.metadataPush(any())).thenReturn(this.voidResult.mono()); + + RSocket assertAuthentication = new RSocketProxy(this.delegate) { + @Override + public Mono metadataPush(Payload payload) { + return assertAuthentication(authentication) + .flatMap(a -> super.metadataPush(payload)); + } + }; + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, + Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); + + StepVerifier.create(interceptor.metadataPush(this.payload)) + .verifyComplete(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.delegate).metadataPush(this.payload); + this.voidResult.assertWasSubscribed(); + } + + // multiple interceptors + + @Test + public void fireAndForgetWhenInterceptorsCompleteThenDelegateInvoked() { + when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); + when(this.interceptor2.intercept(any(), any())).thenAnswer(withChainNext()); + when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono()); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, + dataMimeType); + + interceptor.fireAndForget(this.payload).block(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + this.voidResult.assertWasSubscribed(); + } + + + @Test + public void fireAndForgetWhenInterceptorsMutatesPayloadThenDelegateInvoked() { + when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); + when(this.interceptor2.intercept(any(), any())).thenAnswer(withChainNext()); + when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono()); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, + dataMimeType); + + interceptor.fireAndForget(this.payload).block(); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.interceptor2).intercept(any(), any()); + verify(this.delegate).fireAndForget(eq(this.payload)); + this.voidResult.assertWasSubscribed(); + } + + @Test + public void fireAndForgetWhenInterceptor1ErrorsThenInterceptor2AndDelegateNotInvoked() { + RuntimeException expected = new RuntimeException("Oops"); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, + dataMimeType); + + assertThatCode(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verifyZeroInteractions(this.interceptor2); + this.voidResult.assertWasNotSubscribed(); + } + + @Test + public void fireAndForgetWhenInterceptor2ErrorsThenInterceptor2AndDelegateNotInvoked() { + RuntimeException expected = new RuntimeException("Oops"); + when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); + when(this.interceptor2.intercept(any(), any())).thenReturn(Mono.error(expected)); + + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, + dataMimeType); + + assertThatCode(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected); + + verify(this.interceptor).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); + verify(this.interceptor2).intercept(any(), any()); + this.voidResult.assertWasNotSubscribed(); + } + + private Mono assertAuthentication(Authentication authentication) { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .doOnNext(a -> assertThat(a).isEqualTo(authentication)); + } + + private Answer withAuthenticated(Authentication authentication) { + return invocation -> { + PayloadInterceptorChain c = (PayloadInterceptorChain) invocation.getArguments()[1]; + return c.next(new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, this.metadataMimeType, + this.dataMimeType)) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)); + }; + } + + private static Answer> withChainNext() { + return invocation -> { + PayloadExchange exchange = (PayloadExchange) invocation.getArguments()[0]; + PayloadInterceptorChain chain = (PayloadInterceptorChain) invocation.getArguments()[1]; + return chain.next(exchange); + }; + } +} diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptorTests.java new file mode 100644 index 0000000000..367a1b9b99 --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorInterceptorTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.metadata.WellKnownMimeType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.http.MediaType; +import reactor.core.publisher.Mono; + +import java.util.Arrays; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class PayloadSocketAcceptorInterceptorTests { + @Mock + private PayloadInterceptor interceptor; + + @Mock + private SocketAcceptor socketAcceptor; + + @Mock + private ConnectionSetupPayload setupPayload; + + @Mock + private RSocket rSocket; + + @Mock + private Payload payload; + + private List interceptors; + + private PayloadSocketAcceptorInterceptor acceptorInterceptor; + + @Before + public void setup() { + this.interceptors = Arrays.asList(this.interceptor); + this.acceptorInterceptor = new PayloadSocketAcceptorInterceptor(this.interceptors); + } + + @Test + public void applyWhenDefaultMetadataMimeTypeThenDefaulted() { + when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); + + PayloadExchange exchange = captureExchange(); + + assertThat(exchange.getMetadataMimeType().toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + } + + @Test + public void acceptWhenDefaultMetadataMimeTypeOverrideThenDefaulted() { + this.acceptorInterceptor.setDefaultMetadataMimeType(MediaType.APPLICATION_JSON); + when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); + + PayloadExchange exchange = captureExchange(); + + assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + } + + @Test + public void acceptWhenDefaultDataMimeTypeThenDefaulted() { + this.acceptorInterceptor.setDefaultDataMimeType(MediaType.APPLICATION_JSON); + + PayloadExchange exchange = captureExchange(); + + assertThat(exchange.getMetadataMimeType().toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + } + + private PayloadExchange captureExchange() { + when(this.socketAcceptor.accept(any(), any())).thenReturn(Mono.just(this.rSocket)); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); + + SocketAcceptor wrappedAcceptor = this.acceptorInterceptor.apply(this.socketAcceptor); + RSocket result = wrappedAcceptor.accept(this.setupPayload, this.rSocket).block(); + + assertThat(result).isInstanceOf(PayloadInterceptorRSocket.class); + + when(this.rSocket.fireAndForget(any())).thenReturn(Mono.empty()); + + result.fireAndForget(this.payload).block(); + + ArgumentCaptor exchangeArg = + ArgumentCaptor.forClass(PayloadExchange.class); + verify(this.interceptor, times(2)).intercept(exchangeArg.capture(), any()); + return exchangeArg.getValue(); + } +} \ No newline at end of file diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorTests.java new file mode 100644 index 0000000000..af8154fcd6 --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/PayloadSocketAcceptorTests.java @@ -0,0 +1,160 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.metadata.WellKnownMimeType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.http.MediaType; +import reactor.core.publisher.Mono; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class PayloadSocketAcceptorTests { + + private PayloadSocketAcceptor acceptor; + + private List interceptors; + + @Mock + private SocketAcceptor delegate; + + @Mock + private PayloadInterceptor interceptor; + + @Mock + private ConnectionSetupPayload setupPayload; + + @Mock + private RSocket rSocket; + + @Mock + private Payload payload; + + @Before + public void setup() { + this.interceptors = Arrays.asList(this.interceptor); + this.acceptor = new PayloadSocketAcceptor(this.delegate, this.interceptors); + } + + @Test + public void constructorWhenNullDelegateThenException() { + this.delegate = null; + assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); + } + + @Test + public void constructorWhenNullInterceptorsThenException() { + this.interceptors = null; + assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); + } + + @Test + public void constructorWhenEmptyInterceptorsThenException() { + this.interceptors = Collections.emptyList(); + assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); + } + + @Test + public void acceptWhenDataMimeTypeNullThenException() { + assertThatCode(() -> this.acceptor.accept(this.setupPayload, this.rSocket) + .block()).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void acceptWhenDefaultMetadataMimeTypeThenDefaulted() { + when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); + + PayloadExchange exchange = captureExchange(); + + assertThat(exchange.getMetadataMimeType().toString()) + .isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + } + + @Test + public void acceptWhenDefaultMetadataMimeTypeOverrideThenDefaulted() { + this.acceptor.setDefaultMetadataMimeType(MediaType.APPLICATION_JSON); + when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); + + PayloadExchange exchange = captureExchange(); + + assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + } + + @Test + public void acceptWhenDefaultDataMimeTypeThenDefaulted() { + this.acceptor.setDefaultDataMimeType(MediaType.APPLICATION_JSON); + + PayloadExchange exchange = captureExchange(); + + assertThat(exchange.getMetadataMimeType() + .toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + } + + @Test + public void acceptWhenExplicitMimeTypeThenThenOverrideDefault() { + when(this.setupPayload.metadataMimeType()).thenReturn(MediaType.TEXT_PLAIN_VALUE); + when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); + + PayloadExchange exchange = captureExchange(); + + assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.TEXT_PLAIN); + assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); + } + + private PayloadExchange captureExchange() { + when(this.delegate.accept(any(), any())).thenReturn(Mono.just(this.rSocket)); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); + + RSocket result = this.acceptor.accept(this.setupPayload, this.rSocket).block(); + + assertThat(result).isInstanceOf(PayloadInterceptorRSocket.class); + + when(this.rSocket.fireAndForget(any())).thenReturn(Mono.empty()); + + result.fireAndForget(this.payload).block(); + + ArgumentCaptor exchangeArg = + ArgumentCaptor.forClass(PayloadExchange.class); + verify(this.interceptor, times(2)).intercept(exchangeArg.capture(), any()); + return exchangeArg.getValue(); + } +} \ No newline at end of file diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java new file mode 100644 index 0000000000..5e214875ed --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/interceptor/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.interceptor.authorization; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.security.authorization.AuthorizationDecision; +import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.util.PayloadExchangeAuthorizationContext; +import org.springframework.security.rsocket.util.PayloadExchangeMatcher; +import org.springframework.security.rsocket.util.PayloadExchangeMatcherEntry; +import org.springframework.security.rsocket.util.PayloadExchangeMatchers; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class PayloadExchangeMatcherReactiveAuthorizationManagerTest { + + @Mock + private ReactiveAuthorizationManager authz; + + @Mock + private ReactiveAuthorizationManager authz2; + + @Mock + private PayloadExchange exchange; + + @Test + public void checkWhenGrantedThenGranted() { + AuthorizationDecision expected = new AuthorizationDecision(true); + when(this.authz.check(any(), any())).thenReturn(Mono.just( + expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = + PayloadExchangeMatcherReactiveAuthorizationManager.builder() + .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) + .build(); + + assertThat(manager.check(Mono.empty(), this.exchange).block()) + .isEqualTo(expected); + } + + @Test + public void checkWhenDeniedThenDenied() { + AuthorizationDecision expected = new AuthorizationDecision(false); + when(this.authz.check(any(), any())).thenReturn(Mono.just( + expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = + PayloadExchangeMatcherReactiveAuthorizationManager.builder() + .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) + .build(); + + assertThat(manager.check(Mono.empty(), this.exchange).block()) + .isEqualTo(expected); + } + + @Test + public void checkWhenFirstMatchThenSecondUsed() { + AuthorizationDecision expected = new AuthorizationDecision(true); + when(this.authz.check(any(), any())).thenReturn(Mono.just( + expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = + PayloadExchangeMatcherReactiveAuthorizationManager.builder() + .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) + .add(new PayloadExchangeMatcherEntry<>(e -> PayloadExchangeMatcher.MatchResult.notMatch(), this.authz2)) + .build(); + + assertThat(manager.check(Mono.empty(), this.exchange).block()) + .isEqualTo(expected); + } + + @Test + public void checkWhenSecondMatchThenSecondUsed() { + AuthorizationDecision expected = new AuthorizationDecision(true); + when(this.authz2.check(any(), any())).thenReturn(Mono.just( + expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = + PayloadExchangeMatcherReactiveAuthorizationManager.builder() + .add(new PayloadExchangeMatcherEntry<>(e -> PayloadExchangeMatcher.MatchResult.notMatch(), this.authz)) + .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz2)) + .build(); + + assertThat(manager.check(Mono.empty(), this.exchange).block()) + .isEqualTo(expected); + } +} diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java new file mode 100644 index 0000000000..2654d2378c --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.metadata; + +import org.junit.Test; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.util.MimeType; +import reactor.core.publisher.Mono; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Rob Winch + */ +public class BasicAuthenticationDecoderTests { + @Test + public void basicAuthenticationWhenEncodedThenDecodes() { + BasicAuthenticationEncoder encoder = new BasicAuthenticationEncoder(); + BasicAuthenticationDecoder decoder = new BasicAuthenticationDecoder(); + UsernamePasswordMetadata expectedCredentials = + new UsernamePasswordMetadata("rob", "password"); + DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); + ResolvableType elementType = ResolvableType + .forClass(UsernamePasswordMetadata.class); + MimeType mimeType = UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE; + Map hints = null; + + DataBuffer dataBuffer = encoder.encodeValue(expectedCredentials, factory, + elementType, mimeType, hints); + UsernamePasswordMetadata actualCredentials = decoder + .decodeToMono(Mono.just(dataBuffer), elementType, mimeType, hints).block(); + + assertThat(actualCredentials).isEqualToComparingFieldByField(expectedCredentials); + } + +} \ No newline at end of file diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcherTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcherTests.java new file mode 100644 index 0000000000..8c8c70ac1a --- /dev/null +++ b/rsocket/src/test/java/org/springframework/security/rsocket/util/RoutePayloadExchangeMatcherTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.rsocket.util; + +import io.rsocket.Payload; +import io.rsocket.metadata.WellKnownMimeType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.http.MediaType; +import org.springframework.messaging.rsocket.MetadataExtractor; +import org.springframework.security.rsocket.interceptor.DefaultPayloadExchange; +import org.springframework.security.rsocket.interceptor.PayloadExchange; +import org.springframework.security.rsocket.interceptor.PayloadExchangeType; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.RouteMatcher; + +import java.util.Collections; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class RoutePayloadExchangeMatcherTests { + static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + @Mock + private MetadataExtractor metadataExtractor; + + @Mock + private RouteMatcher routeMatcher; + + private PayloadExchange exchange; + + @Mock + private Payload payload; + + @Mock + private RouteMatcher.Route route; + + private String pattern; + + private RoutePayloadExchangeMatcher matcher; + + @Before + public void setup() { + this.pattern = "a.b"; + this.matcher = new RoutePayloadExchangeMatcher(this.metadataExtractor, this.routeMatcher, this.pattern); + this.exchange = new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, COMPOSITE_METADATA, + MediaType.APPLICATION_JSON); + } + + @Test + public void matchesWhenNoRouteThenNotMatch() { + when(this.metadataExtractor.extract(any(), any())) + .thenReturn(Collections.emptyMap()); + PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); + assertThat(result.isMatch()).isFalse(); + } + + @Test + public void matchesWhenNotMatchThenNotMatch() { + String route = "route"; + when(this.metadataExtractor.extract(any(), any())) + .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); + PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); + assertThat(result.isMatch()).isFalse(); + } + + @Test + public void matchesWhenMatchAndNoVariablesThenMatch() { + String route = "route"; + when(this.metadataExtractor.extract(any(), any())) + .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); + when(this.routeMatcher.parseRoute(any())).thenReturn(this.route); + when(this.routeMatcher.matchAndExtract(any(), any())).thenReturn(Collections.emptyMap()); + PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); + assertThat(result.isMatch()).isTrue(); + } + + @Test + public void matchesWhenMatchAndVariablesThenMatchAndVariables() { + String route = "route"; + Map variables = Collections.singletonMap("a", "b"); + when(this.metadataExtractor.extract(any(), any())) + .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); + when(this.routeMatcher.parseRoute(any())).thenReturn(this.route); + when(this.routeMatcher.matchAndExtract(any(), any())).thenReturn(variables); + PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); + assertThat(result.isMatch()).isTrue(); + assertThat(result.getVariables()).containsAllEntriesOf(variables); + } +} \ No newline at end of file