diff --git a/config/src/main/java/org/springframework/security/config/web/server/HttpMessageConverters.java b/config/src/main/java/org/springframework/security/config/web/server/HttpMessageConverters.java new file mode 100644 index 0000000000..cda41044ca --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/web/server/HttpMessageConverters.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.web.server; + +import org.springframework.http.converter.GenericHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.json.GsonHttpMessageConverter; +import org.springframework.http.converter.json.JsonbHttpMessageConverter; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.util.ClassUtils; + +/** + * Utility methods for {@link HttpMessageConverter}'s. + * + * @author Joe Grandja + * @author luamas + * @since 5.1 + */ +final class HttpMessageConverters { + + private static final boolean jackson2Present; + + private static final boolean gsonPresent; + + private static final boolean jsonbPresent; + + static { + ClassLoader classLoader = HttpMessageConverters.class.getClassLoader(); + jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader) + && ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader); + gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", classLoader); + jsonbPresent = ClassUtils.isPresent("jakarta.json.bind.Jsonb", classLoader); + } + + private HttpMessageConverters() { + } + + static GenericHttpMessageConverter getJsonMessageConverter() { + if (jackson2Present) { + return new MappingJackson2HttpMessageConverter(); + } + if (gsonPresent) { + return new GsonHttpMessageConverter(); + } + if (jsonbPresent) { + return new JsonbHttpMessageConverter(); + } + return null; + } + +} diff --git a/config/src/main/java/org/springframework/security/config/web/server/OAuth2ErrorEncoder.java b/config/src/main/java/org/springframework/security/config/web/server/OAuth2ErrorEncoder.java new file mode 100644 index 0000000000..784782344c --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/web/server/OAuth2ErrorEncoder.java @@ -0,0 +1,95 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.web.server; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.jetbrains.annotations.NotNull; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageEncoder; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.util.MimeType; + +class OAuth2ErrorEncoder implements HttpMessageEncoder { + + private final HttpMessageConverter messageConverter = HttpMessageConverters.getJsonMessageConverter(); + + @NotNull + @Override + public List getStreamingMediaTypes() { + return List.of(); + } + + @Override + public boolean canEncode(ResolvableType elementType, MimeType mimeType) { + return getEncodableMimeTypes().contains(mimeType); + } + + @NotNull + @Override + public Flux encode(Publisher error, DataBufferFactory bufferFactory, + ResolvableType elementType, MimeType mimeType, Map hints) { + return Mono.from(error).flatMap((data) -> { + ByteArrayHttpOutputMessage bytes = new ByteArrayHttpOutputMessage(); + try { + this.messageConverter.write(data, MediaType.APPLICATION_JSON, bytes); + return Mono.just(bytes.getBody().toByteArray()); + } + catch (IOException ex) { + return Mono.error(ex); + } + }).map(bufferFactory::wrap).flux(); + } + + @NotNull + @Override + public List getEncodableMimeTypes() { + return List.of(MediaType.APPLICATION_JSON); + } + + private static class ByteArrayHttpOutputMessage implements HttpOutputMessage { + + private final ByteArrayOutputStream body = new ByteArrayOutputStream(); + + @NotNull + @Override + public ByteArrayOutputStream getBody() { + return this.body; + } + + @NotNull + @Override + public HttpHeaders getHeaders() { + return new HttpHeaders(); + } + + } + +} diff --git a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java index 74f5f32e68..8f1788c498 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java +++ b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java @@ -16,16 +16,17 @@ package org.springframework.security.config.web.server; -import java.nio.charset.StandardCharsets; +import java.util.Collections; import jakarta.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.core.ResolvableType; +import org.springframework.http.MediaType; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.HttpMessageWriter; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.authentication.ReactiveAuthenticationManager; @@ -62,6 +63,9 @@ class OidcBackChannelLogoutWebFilter implements WebFilter { private ServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler(); + private final HttpMessageWriter errorHttpMessageConverter = new EncoderHttpMessageWriter<>( + new OAuth2ErrorEncoder()); + /** * Construct an {@link OidcBackChannelLogoutWebFilter} * @param authenticationConverter the {@link AuthenticationConverter} for deriving @@ -84,7 +88,7 @@ class OidcBackChannelLogoutWebFilter implements WebFilter { if (ex instanceof AuthenticationServiceException) { return Mono.error(ex); } - return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty()); + return handleAuthenticationFailure(exchange, ex).then(Mono.empty()); }) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .flatMap(this.authenticationManager::authenticate) @@ -93,7 +97,7 @@ class OidcBackChannelLogoutWebFilter implements WebFilter { if (ex instanceof AuthenticationServiceException) { return Mono.error(ex); } - return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty()); + return handleAuthenticationFailure(exchange, ex).then(Mono.empty()); }) .flatMap((authentication) -> { WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); @@ -101,19 +105,12 @@ class OidcBackChannelLogoutWebFilter implements WebFilter { }); } - private Mono handleAuthenticationFailure(ServerHttpResponse response, Exception ex) { + private Mono handleAuthenticationFailure(ServerWebExchange exchange, Exception ex) { this.logger.debug("Failed to process OIDC Back-Channel Logout", ex); - response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST); - OAuth2Error error = oauth2Error(ex); - byte[] bytes = String.format(""" - { - "error_code": "%s", - "error_description": "%s", - "error_uri: "%s" - } - """, error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8); - DataBuffer buffer = response.bufferFactory().wrap(bytes); - return response.writeWith(Flux.just(buffer)); + exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST); + return this.errorHttpMessageConverter.write(Mono.just(oauth2Error(ex)), ResolvableType.forClass(Object.class), + ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(), + exchange.getResponse(), Collections.emptyMap()); } private OAuth2Error oauth2Error(Exception ex) { diff --git a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java index 5312a6da7c..c0c1e73bc6 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java +++ b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java @@ -16,8 +16,8 @@ package org.springframework.security.config.web.server; -import java.nio.charset.StandardCharsets; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -25,14 +25,15 @@ import java.util.concurrent.atomic.AtomicInteger; import jakarta.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; import org.springframework.security.oauth2.client.oidc.server.session.InMemoryReactiveOidcSessionRegistry; @@ -44,6 +45,7 @@ import org.springframework.security.web.server.WebFilterExchange; import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler; import org.springframework.util.Assert; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; @@ -63,6 +65,9 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler { private ReactiveOidcSessionRegistry sessionRegistry = new InMemoryReactiveOidcSessionRegistry(); + private final HttpMessageWriter errorHttpMessageConverter = new EncoderHttpMessageWriter<>( + new OAuth2ErrorEncoder()); + private WebClient web = WebClient.create(); private String logoutUri = "{baseScheme}://localhost{basePort}/logout"; @@ -97,7 +102,7 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler { totalCount.intValue())); } if (!list.isEmpty()) { - return handleLogoutFailure(exchange.getExchange().getResponse(), oauth2Error(list)); + return handleLogoutFailure(exchange.getExchange(), oauth2Error(list)); } else { return Mono.empty(); @@ -148,17 +153,11 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler { "https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation"); } - private Mono handleLogoutFailure(ServerHttpResponse response, OAuth2Error error) { - response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST); - byte[] bytes = String.format(""" - { - "error_code": "%s", - "error_description": "%s", - "error_uri: "%s" - } - """, error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8); - DataBuffer buffer = response.bufferFactory().wrap(bytes); - return response.writeWith(Flux.just(buffer)); + private Mono handleLogoutFailure(ServerWebExchange exchange, OAuth2Error error) { + exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST); + return this.errorHttpMessageConverter.write(Mono.just(error), ResolvableType.forClass(Object.class), + ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(), + exchange.getResponse(), Collections.emptyMap()); } /** diff --git a/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java index dd4cf77eb8..461f7ac47e 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java @@ -50,6 +50,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.annotation.Order; import org.springframework.http.ResponseCookie; import org.springframework.http.client.reactive.ClientHttpConnector; @@ -99,6 +100,7 @@ import org.springframework.web.server.WebSession; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; @@ -197,7 +199,10 @@ public class OidcLogoutSpecTests { .body(BodyInserters.fromFormData("logout_token", "invalid")) .exchange() .expectStatus() - .isBadRequest(); + .isBadRequest() + .expectBody(new ParameterizedTypeReference>() { + }) + .value(hasValue("invalid_request")); this.test.get().uri("/token/logout").cookie("SESSION", session).exchange().expectStatus().isOk(); } @@ -264,9 +269,10 @@ public class OidcLogoutSpecTests { .exchange() .expectStatus() .isBadRequest() - .expectBody(String.class) - .value(containsString("partial_logout")) - .value(containsString("not all sessions were terminated")); + .expectBody(new ParameterizedTypeReference>() { + }) + .value(hasValue("partial_logout")) + .value(hasValue(containsString("not all sessions were terminated"))); this.test.get().uri("/token/logout").cookie("SESSION", one).exchange().expectStatus().isOk(); }