Issue #8216 - improve testing for end_session_endpoint

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2022-07-12 15:17:08 +10:00
parent 26732c90a0
commit 90fe5621f9
4 changed files with 86 additions and 25 deletions

View File

@ -260,6 +260,7 @@ public class OpenIdAuthenticator extends LoginAuthenticator
session.removeAttribute(SessionAuthentication.__J_AUTHENTICATED);
session.removeAttribute(CLAIMS);
session.removeAttribute(RESPONSE);
session.removeAttribute(ISSUER);
}
}
@ -269,30 +270,33 @@ public class OpenIdAuthenticator extends LoginAuthenticator
{
Request baseRequest = Objects.requireNonNull(Request.getBaseRequest(request));
Response baseResponse = baseRequest.getResponse();
String endSessionEndpoint = _openIdConfiguration.getEndSessionEndpoint();
if (endSessionEndpoint == null)
return;
StringBuilder redirectUri = new StringBuilder(128);
URIUtil.appendSchemeHostPort(redirectUri, request.getScheme(), request.getServerName(), request.getServerPort());
redirectUri.append(baseRequest.getContextPath());
redirectUri.append(_logoutRedirectPath);
String endSessionEndpoint = _openIdConfiguration.getEndSessionEndpoint();
HttpSession session = baseRequest.getSession(false);
if (session == null)
if (endSessionEndpoint == null || session == null)
{
baseResponse.sendRedirect(redirectUri.toString(), true);
return;
}
Object openIdResponse = session.getAttribute(OpenIdAuthenticator.RESPONSE);
if (openIdResponse instanceof Map)
if (!(openIdResponse instanceof Map))
{
@SuppressWarnings("rawtypes")
String idToken = (String)((Map)openIdResponse).get("id_token");
baseResponse.sendRedirect(endSessionEndpoint +
"?id_token_hint=" + UrlEncoded.encodeString(idToken, StandardCharsets.UTF_8) +
"&post_logout_redirect_uri=" + UrlEncoded.encodeString(redirectUri.toString(), StandardCharsets.UTF_8),
true);
baseResponse.sendRedirect(redirectUri.toString(), true);
return;
}
@SuppressWarnings("rawtypes")
String idToken = (String)((Map)openIdResponse).get("id_token");
baseResponse.sendRedirect(endSessionEndpoint +
"?id_token_hint=" + UrlEncoded.encodeString(idToken, StandardCharsets.UTF_8) +
"&post_logout_redirect_uri=" + UrlEncoded.encodeString(redirectUri.toString(), StandardCharsets.UTF_8),
true);
}
catch (Throwable t)
{

View File

@ -42,6 +42,7 @@ public class OpenIdConfiguration extends ContainerLifeCycle
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTHORIZATION_ENDPOINT = "authorization_endpoint";
private static final String TOKEN_ENDPOINT = "token_endpoint";
private static final String END_SESSION_ENDPOINT = "end_session_endpoint";
private static final String ISSUER = "issuer";
private final HttpClient httpClient;
@ -165,9 +166,9 @@ public class OpenIdConfiguration extends ContainerLifeCycle
if (tokenEndpoint == null)
throw new IllegalStateException(TOKEN_ENDPOINT);
endSessionEndpoint = (String)discoveryDocument.get("end_session_endpoint");
// End session endpoint is optional.
if (endSessionEndpoint == null)
throw new IllegalArgumentException("end_session_endpoint");
endSessionEndpoint = (String)discoveryDocument.get(END_SESSION_ENDPOINT);
// We are lenient and not throw here as some major OIDC providers do not conform to this.
if (!Objects.equals(discoveryDocument.get(ISSUER), issuer))

View File

@ -16,6 +16,7 @@ package org.eclipse.jetty.security.openid;
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -35,6 +36,7 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
@ -155,6 +157,11 @@ public class OpenIdAuthenticationTest
assertThat(response.getStatus(), is(HttpStatus.OK_200));
content = response.getContentAsString();
assertThat(content, containsString("not authenticated"));
// Test that the user was logged out successfully on the openid provider.
assertThat(openIdProvider.getLoggedInUsers().getCurrent(), equalTo(0L));
assertThat(openIdProvider.getLoggedInUsers().getMax(), equalTo(1L));
assertThat(openIdProvider.getLoggedInUsers().getTotal(), equalTo(1L));
}
public static class LoginPage extends HttpServlet
@ -171,10 +178,9 @@ public class OpenIdAuthenticationTest
public static class LogoutPage extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException
{
request.getSession().invalidate();
response.sendRedirect("/");
request.logout();
}
}

View File

@ -38,6 +38,7 @@ import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -48,6 +49,7 @@ public class OpenIdProvider extends ContainerLifeCycle
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();
protected final String clientId;
@ -58,6 +60,7 @@ public class OpenIdProvider extends ContainerLifeCycle
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();
public static void main(String[] args) throws Exception
{
@ -91,9 +94,10 @@ public class OpenIdProvider extends ContainerLifeCycle
ServletContextHandler contextHandler = new ServletContextHandler();
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new OpenIdConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new OpenIdAuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new OpenIdTokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new ConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new AuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new TokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new EndSessionEndpoint()), END_SESSION_PATH);
server.setHandler(contextHandler);
addBean(server);
@ -112,6 +116,11 @@ public class OpenIdProvider extends ContainerLifeCycle
return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
}
public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}
@Override
protected void doStart() throws Exception
{
@ -144,7 +153,7 @@ public class OpenIdProvider extends ContainerLifeCycle
redirectUris.add(uri);
}
public class OpenIdAuthEndpoint extends HttpServlet
public class AuthEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
@ -252,7 +261,7 @@ public class OpenIdProvider extends ContainerLifeCycle
}
}
public class OpenIdTokenEndpoint extends HttpServlet
private class TokenEndpoint extends HttpServlet
{
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
@ -285,12 +294,44 @@ public class OpenIdProvider extends ContainerLifeCycle
"\"token_type\": \"Bearer\"" +
"}";
loggedInUsers.increment();
resp.setContentType("text/plain");
resp.getWriter().print(response);
}
}
public class OpenIdConfigServlet extends HttpServlet
private class EndSessionEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String idToken = req.getParameter("id_token_hint");
if (idToken == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no id_token_hint");
return;
}
String logoutRedirect = req.getParameter("post_logout_redirect_uri");
if (logoutRedirect == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no post_logout_redirect_uri");
return;
}
loggedInUsers.decrement();
resp.setContentType("text/plain");
resp.sendRedirect(logoutRedirect);
}
}
private class ConfigServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
@ -299,6 +340,7 @@ public class OpenIdProvider extends ContainerLifeCycle
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";
resp.getWriter().write(discoveryDocument);
@ -336,5 +378,13 @@ public class OpenIdProvider extends ContainerLifeCycle
long expiry = System.currentTimeMillis() + Duration.ofMinutes(1).toMillis();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiry);
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
}
}