reduce duplication for OpenIdProvider testing class

Signed-off-by: Lachlan Roberts <lachlan.p.roberts@gmail.com>
This commit is contained in:
Lachlan Roberts 2024-08-26 11:48:10 +10:00
parent 9756387020
commit 063e839962
No known key found for this signature in database
GPG Key ID: 5663FB7A8FF7E348
26 changed files with 493 additions and 2269 deletions

View File

@ -52,6 +52,11 @@
<artifactId>jetty-ee8-servlet</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-test-util</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.toolchain</groupId>
<artifactId>jetty-test-helper</artifactId>

View File

@ -55,6 +55,11 @@
<artifactId>jetty-ee9-servlet</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-test-util</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.toolchain</groupId>
<artifactId>jetty-test-helper</artifactId>

View File

@ -1,53 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee9.security.openid;
import java.util.Base64;
/**
* A basic JWT encoder for testing purposes.
*/
public class JwtEncoder
{
private static final Base64.Encoder ENCODER = Base64.getUrlEncoder();
private static final String DEFAULT_HEADER = "{\"INFO\": \"this is not used or checked in our implementation\"}";
private static final String DEFAULT_SIGNATURE = "we do not validate signature as we use the authorization code flow";
public static String encode(String idToken)
{
return stripPadding(ENCODER.encodeToString(DEFAULT_HEADER.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(idToken.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(DEFAULT_SIGNATURE.getBytes()));
}
private static String stripPadding(String paddedBase64)
{
return paddedBase64.split("=")[0];
}
/**
* Create a basic JWT for testing using argument supplied attributes.
*/
public static String createIdToken(String provider, String clientId, String subject, String name, long expiry)
{
return "{" +
"\"iss\": \"" + provider + "\"," +
"\"sub\": \"" + subject + "\"," +
"\"aud\": \"" + clientId + "\"," +
"\"exp\": " + expiry + "," +
"\"name\": \"" + name + "\"," +
"\"email\": \"" + name + "@example.com" + "\"" +
"}";
}
}

View File

@ -43,6 +43,7 @@ import org.eclipse.jetty.security.openid.OpenIdConfiguration;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.session.FileSessionDataStoreFactory;
import org.eclipse.jetty.tests.OpenIdProvider;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.security.Password;

View File

@ -1,402 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee9.security.openid;
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.ee9.servlet.ServletContextHandler;
import org.eclipse.jetty.ee9.servlet.ServletHolder;
import org.eclipse.jetty.security.openid.OpenIdConfiguration;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OpenIdProvider extends ContainerLifeCycle
{
private static final Logger LOG = LoggerFactory.getLogger(OpenIdProvider.class);
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();
protected final String clientId;
protected final String clientSecret;
protected final List<String> redirectUris = new ArrayList<>();
private final ServerConnector connector;
private final Server server;
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();
private long _idTokenDuration = Duration.ofSeconds(10).toMillis();
public static void main(String[] args) throws Exception
{
String clientId = "CLIENT_ID123";
String clientSecret = "PASSWORD123";
int port = 5771;
String redirectUri = "http://localhost:8080/j_security_check";
OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret);
openIdProvider.addRedirectUri(redirectUri);
openIdProvider.setPort(port);
openIdProvider.start();
try
{
openIdProvider.join();
}
finally
{
openIdProvider.stop();
}
}
public OpenIdProvider(String clientId, String clientSecret)
{
this.clientId = clientId;
this.clientSecret = clientSecret;
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler();
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new ConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new AuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new TokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new EndSessionEndpoint()), END_SESSION_PATH);
server.setHandler(contextHandler);
addBean(server);
}
public void setIdTokenDuration(long duration)
{
_idTokenDuration = duration;
}
public long getIdTokenDuration()
{
return _idTokenDuration;
}
public void join() throws InterruptedException
{
server.join();
}
public OpenIdConfiguration getOpenIdConfiguration()
{
String provider = getProvider();
String authEndpoint = provider + AUTH_PATH;
String tokenEndpoint = provider + TOKEN_PATH;
return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
}
public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}
@Override
protected void doStart() throws Exception
{
connector.setPort(port);
super.doStart();
provider = "http://localhost:" + connector.getLocalPort();
}
public void setPort(int port)
{
if (isStarted())
throw new IllegalStateException();
this.port = port;
}
public void setUser(User user)
{
this.preAuthedUser = user;
}
public String getProvider()
{
if (!isStarted() && port == 0)
throw new IllegalStateException("Port of OpenIdProvider not configured");
return provider;
}
public void addRedirectUri(String uri)
{
redirectUris.add(uri);
}
public class AuthEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
if (!clientId.equals(req.getParameter("client_id")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid client_id");
return;
}
String redirectUri = req.getParameter("redirect_uri");
if (!redirectUris.contains(redirectUri))
{
LOG.warn("invalid redirectUri {}", redirectUri);
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String scopeString = req.getParameter("scope");
List<String> scopes = (scopeString == null) ? Collections.emptyList() : Arrays.asList(scopeString.split(" "));
if (!scopes.contains("openid"))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no openid scope");
return;
}
if (!"code".equals(req.getParameter("response_type")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "response_type must be code");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
if (preAuthedUser == null)
{
PrintWriter writer = resp.getWriter();
resp.setContentType("text/html");
writer.println("<h2>Login to OpenID Connect Provider</h2>");
writer.println("<form action=\"" + AUTH_PATH + "\" method=\"post\">");
writer.println("<input type=\"text\" autocomplete=\"off\" placeholder=\"Username\" name=\"username\" required>");
writer.println("<input type=\"hidden\" name=\"redirectUri\" value=\"" + redirectUri + "\">");
writer.println("<input type=\"hidden\" name=\"state\" value=\"" + state + "\">");
writer.println("<input type=\"submit\">");
writer.println("</form>");
}
else
{
redirectUser(resp, preAuthedUser, redirectUri, state);
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String redirectUri = req.getParameter("redirectUri");
if (!redirectUris.contains(redirectUri))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
String username = req.getParameter("username");
if (username == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no username");
return;
}
User user = new User(username);
redirectUser(resp, user, redirectUri, state);
}
public void redirectUser(HttpServletResponse response, User user, String redirectUri, String state) throws IOException
{
String authCode = UUID.randomUUID().toString().replace("-", "");
issuedAuthCodes.put(authCode, user);
try
{
redirectUri += "?code=" + authCode + "&state=" + state;
response.sendRedirect(response.encodeRedirectURL(redirectUri));
}
catch (Throwable t)
{
issuedAuthCodes.remove(authCode);
throw t;
}
}
}
private class TokenEndpoint extends HttpServlet
{
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
String code = req.getParameter("code");
if (!clientId.equals(req.getParameter("client_id")) ||
!clientSecret.equals(req.getParameter("client_secret")) ||
!redirectUris.contains(req.getParameter("redirect_uri")) ||
!"authorization_code".equals(req.getParameter("grant_type")) ||
code == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "bad auth request");
return;
}
User user = issuedAuthCodes.remove(code);
if (user == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid auth code");
return;
}
String accessToken = "ABCDEFG";
long accessTokenDuration = Duration.ofMinutes(10).toSeconds();
String response = "{" +
"\"access_token\": \"" + accessToken + "\"," +
"\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId, _idTokenDuration)) + "\"," +
"\"expires_in\": " + accessTokenDuration + "," +
"\"token_type\": \"Bearer\"" +
"}";
loggedInUsers.increment();
resp.setContentType("text/plain");
resp.getWriter().print(response);
}
}
private class EndSessionEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String idToken = req.getParameter("id_token_hint");
if (idToken == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no id_token_hint");
return;
}
String logoutRedirect = req.getParameter("post_logout_redirect_uri");
if (logoutRedirect == null)
{
resp.setStatus(HttpServletResponse.SC_OK);
resp.getWriter().println("logout success on end_session_endpoint");
return;
}
loggedInUsers.decrement();
resp.setContentType("text/plain");
resp.sendRedirect(logoutRedirect);
}
}
private class ConfigServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String discoveryDocument = "{" +
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";
resp.getWriter().write(discoveryDocument);
}
}
public static class User
{
private final String subject;
private final String name;
public User(String name)
{
this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name);
}
public User(String subject, String name)
{
this.subject = subject;
this.name = name;
}
public String getName()
{
return name;
}
public String getSubject()
{
return subject;
}
public String getIdToken(String provider, String clientId, long duration)
{
long expiryTime = Instant.now().plusMillis(duration).getEpochSecond();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiryTime);
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
@Override
public int hashCode()
{
return Objects.hash(subject, name);
}
}
}

View File

@ -47,6 +47,11 @@
<artifactId>jetty-slf4j-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-test-util</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.toolchain</groupId>
<artifactId>jetty-test-helper</artifactId>

View File

@ -16,6 +16,7 @@ package org.eclipse.jetty.security.openid;
import java.util.Map;
import java.util.stream.Stream;
import org.eclipse.jetty.tests.JwtEncoder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;

View File

@ -43,6 +43,7 @@ import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.server.handler.ContextHandlerCollection;
import org.eclipse.jetty.session.FileSessionDataStoreFactory;
import org.eclipse.jetty.session.SessionHandler;
import org.eclipse.jetty.tests.OpenIdProvider;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IO;

View File

@ -1,445 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.security.openid;
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.server.handler.ContextHandlerCollection;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Fields;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OpenIdProvider extends ContainerLifeCycle
{
private static final Logger LOG = LoggerFactory.getLogger(OpenIdProvider.class);
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();
protected final String clientId;
protected final String clientSecret;
protected final List<String> redirectUris = new ArrayList<>();
private final ServerConnector connector;
private final Server server;
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();
private long _idTokenDuration = Duration.ofSeconds(10).toMillis();
public static void main(String[] args) throws Exception
{
String clientId = "CLIENT_ID123";
String clientSecret = "PASSWORD123";
int port = 5771;
String redirectUri = "http://localhost:8080/j_security_check";
OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret);
openIdProvider.addRedirectUri(redirectUri);
openIdProvider.setPort(port);
openIdProvider.start();
try
{
openIdProvider.join();
}
finally
{
openIdProvider.stop();
}
}
public OpenIdProvider(String clientId, String clientSecret)
{
this.clientId = clientId;
this.clientSecret = clientSecret;
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
ContextHandlerCollection contexts = new ContextHandlerCollection();
contexts.addHandler(new ConfigServlet(CONFIG_PATH));
contexts.addHandler(new AuthEndpoint(AUTH_PATH));
contexts.addHandler(new TokenEndpoint(TOKEN_PATH));
contexts.addHandler(new EndSessionEndpoint(END_SESSION_PATH));
server.setHandler(contexts);
addBean(server);
}
public void setIdTokenDuration(long duration)
{
_idTokenDuration = duration;
}
public long getIdTokenDuration()
{
return _idTokenDuration;
}
public void join() throws InterruptedException
{
server.join();
}
public OpenIdConfiguration getOpenIdConfiguration()
{
String provider = getProvider();
String authEndpoint = provider + AUTH_PATH;
String tokenEndpoint = provider + TOKEN_PATH;
return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
}
public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}
@Override
protected void doStart() throws Exception
{
connector.setPort(port);
super.doStart();
provider = "http://localhost:" + connector.getLocalPort();
}
public void setPort(int port)
{
if (isStarted())
throw new IllegalStateException();
this.port = port;
}
public void setUser(User user)
{
this.preAuthedUser = user;
}
public String getProvider()
{
if (!isStarted() && port == 0)
throw new IllegalStateException("Port of OpenIdProvider not configured");
return provider;
}
public void addRedirectUri(String uri)
{
redirectUris.add(uri);
}
public class AuthEndpoint extends ContextHandler
{
public AuthEndpoint(String contextPath)
{
super(contextPath);
}
@Override
public boolean handle(Request request, Response response, Callback callback) throws Exception
{
switch (request.getMethod())
{
case "GET":
doGet(request, response, callback);
break;
case "POST":
doPost(request, response, callback);
break;
default:
throw new BadMessageException("Unsupported HTTP Method");
}
return true;
}
protected void doGet(Request request, Response response, Callback callback) throws Exception
{
Fields parameters = Request.getParameters(request);
if (!clientId.equals(parameters.getValue("client_id")))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid client_id");
return;
}
String redirectUri = parameters.getValue("redirect_uri");
if (!redirectUris.contains(redirectUri))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid redirect_uri");
return;
}
String scopeString = parameters.getValue("scope");
List<String> scopes = (scopeString == null) ? Collections.emptyList() : Arrays.asList(StringUtil.csvSplit(scopeString));
if (!scopes.contains("openid"))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no openid scope");
return;
}
if (!"code".equals(parameters.getValue("response_type")))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "response_type must be code");
return;
}
String state = parameters.getValue("state");
if (state == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no state param");
return;
}
if (preAuthedUser == null)
{
response.getHeaders().add(HttpHeader.CONTENT_TYPE, "text/html");
String content =
"<h2>Login to OpenID Connect Provider</h2>" +
"<form action=\"" + AUTH_PATH + "\" method=\"post\">" +
"<input type=\"text\" autocomplete=\"off\" placeholder=\"Username\" name=\"username\" required>" +
"<input type=\"hidden\" name=\"redirectUri\" value=\"" + redirectUri + "\">" +
"<input type=\"hidden\" name=\"state\" value=\"" + state + "\">" +
"<input type=\"submit\">" +
"</form>";
response.write(true, BufferUtil.toBuffer(content), callback);
}
else
{
redirectUser(request, response, callback, preAuthedUser, redirectUri, state);
}
}
protected void doPost(Request request, Response response, Callback callback) throws Exception
{
Fields parameters = Request.getParameters(request);
String redirectUri = parameters.getValue("redirectUri");
if (!redirectUris.contains(redirectUri))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid redirect_uri");
return;
}
String state = parameters.getValue("state");
if (state == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no state param");
return;
}
String username = parameters.getValue("username");
if (username == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no username");
return;
}
User user = new User(username);
redirectUser(request, response, callback, user, redirectUri, state);
}
public void redirectUser(Request request, Response response, Callback callback, User user, String redirectUri, String state) throws IOException
{
String authCode = UUID.randomUUID().toString().replace("-", "");
issuedAuthCodes.put(authCode, user);
try
{
redirectUri += "?code=" + authCode + "&state=" + state;
Response.sendRedirect(request, response, callback, redirectUri);
}
catch (Throwable t)
{
issuedAuthCodes.remove(authCode);
throw t;
}
}
}
private class TokenEndpoint extends ContextHandler
{
public TokenEndpoint(String contextPath)
{
super(contextPath);
}
@Override
public boolean handle(Request request, Response response, Callback callback) throws Exception
{
Fields parameters = Request.getParameters(request);
String code = parameters.getValue("code");
if (!clientId.equals(parameters.getValue("client_id")) ||
!clientSecret.equals(parameters.getValue("client_secret")) ||
!redirectUris.contains(parameters.getValue("redirect_uri")) ||
!"authorization_code".equals(parameters.getValue("grant_type")) ||
code == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "bad auth request");
return true;
}
User user = issuedAuthCodes.remove(code);
if (user == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid auth code");
return true;
}
String accessToken = "ABCDEFG";
long accessTokenDuration = Duration.ofMinutes(10).toSeconds();
String content = "{" +
"\"access_token\": \"" + accessToken + "\"," +
"\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId, _idTokenDuration)) + "\"," +
"\"expires_in\": " + accessTokenDuration + "," +
"\"token_type\": \"Bearer\"" +
"}";
loggedInUsers.increment();
response.getHeaders().put(HttpHeader.CONTENT_TYPE, "text/plain");
response.write(true, BufferUtil.toBuffer(content), callback);
return true;
}
}
private class EndSessionEndpoint extends ContextHandler
{
public EndSessionEndpoint(String contextPath)
{
super(contextPath);
}
@Override
public boolean handle(Request request, Response response, Callback callback) throws Exception
{
Fields parameters = Request.getParameters(request);
String idToken = parameters.getValue("id_token_hint");
if (idToken == null)
{
Response.writeError(request, response, callback, HttpStatus.BAD_REQUEST_400, "no id_token_hint");
return true;
}
String logoutRedirect = parameters.getValue("post_logout_redirect_uri");
if (logoutRedirect == null)
{
response.setStatus(HttpStatus.OK_200);
response.write(true, BufferUtil.toBuffer("logout success on end_session_endpoint"), callback);
return true;
}
loggedInUsers.decrement();
response.getHeaders().put(HttpHeader.CONTENT_TYPE, "text/plain");
Response.sendRedirect(request, response, callback, logoutRedirect);
return true;
}
}
private class ConfigServlet extends ContextHandler
{
public ConfigServlet(String contextPath)
{
super(contextPath);
}
@Override
public boolean handle(Request request, Response response, Callback callback) throws Exception
{
String discoveryDocument = "{" +
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";
response.write(true, BufferUtil.toBuffer(discoveryDocument), callback);
return true;
}
}
public static class User
{
private final String subject;
private final String name;
public User(String name)
{
this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name);
}
public User(String subject, String name)
{
this.subject = subject;
this.name = name;
}
public String getName()
{
return name;
}
public String getSubject()
{
return subject;
}
public String getIdToken(String provider, String clientId, long duration)
{
long expiryTime = Instant.now().plusMillis(duration).getEpochSecond();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiryTime);
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
@Override
public int hashCode()
{
return Objects.hash(subject, name);
}
}
}

View File

@ -1060,6 +1060,11 @@
<artifactId>jetty-test-session-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-test-util</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-testers</artifactId>

View File

@ -0,0 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>tests</artifactId>
<version>12.1.0-SNAPSHOT</version>
</parent>
<artifactId>jetty-test-util</artifactId>
<packaging>jar</packaging>
<name>Tests :: Test Utilities</name>
<properties>
<bundle-symbolic-name>${project.groupId}.testers</bundle-symbolic-name>
</properties>
<dependencies>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-util</artifactId>
</dependency>
</dependencies>
</project>

View File

@ -11,7 +11,7 @@
// ========================================================================
//
package org.eclipse.jetty.security.openid;
package org.eclipse.jetty.tests;
import java.util.Base64;

View File

@ -0,0 +1,421 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.tests;
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpMethod;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Fields;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OpenIdProvider extends ContainerLifeCycle
{
private static final Logger LOG = LoggerFactory.getLogger(OpenIdProvider.class);
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();
protected final String clientId;
protected final String clientSecret;
protected final List<String> redirectUris = new ArrayList<>();
private final ServerConnector connector;
private final Server server;
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();
private long _idTokenDuration = Duration.ofSeconds(10).toMillis();
public static void main(String[] args) throws Exception
{
String clientId = "CLIENT_ID123";
String clientSecret = "PASSWORD123";
int port = 5771;
String redirectUri = "http://localhost:8080/j_security_check";
OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret);
openIdProvider.addRedirectUri(redirectUri);
openIdProvider.setPort(port);
openIdProvider.start();
try
{
openIdProvider.join();
}
finally
{
openIdProvider.stop();
}
}
public OpenIdProvider()
{
this("clientId" + StringUtil.randomAlphaNumeric(4), StringUtil.randomAlphaNumeric(10));
}
public OpenIdProvider(String clientId, String clientSecret)
{
this.clientId = clientId;
this.clientSecret = clientSecret;
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
server.setHandler(new OpenIdProviderHandler());
addBean(server);
}
public String getClientId()
{
return clientId;
}
public String getClientSecret()
{
return clientSecret;
}
public void setIdTokenDuration(long duration)
{
_idTokenDuration = duration;
}
public long getIdTokenDuration()
{
return _idTokenDuration;
}
public void join() throws InterruptedException
{
server.join();
}
public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}
@Override
protected void doStart() throws Exception
{
connector.setPort(port);
super.doStart();
provider = "http://localhost:" + connector.getLocalPort();
}
public void setPort(int port)
{
if (isStarted())
throw new IllegalStateException();
this.port = port;
}
public void setUser(User user)
{
this.preAuthedUser = user;
}
public String getProvider()
{
if (!isStarted() && port == 0)
throw new IllegalStateException("Port of OpenIdProvider not configured");
return provider;
}
public void addRedirectUri(String uri)
{
redirectUris.add(uri);
}
public class OpenIdProviderHandler extends Handler.Abstract
{
@Override
public boolean handle(Request request, Response response, Callback callback) throws Exception
{
String pathInContext = Request.getPathInContext(request);
switch (pathInContext)
{
case CONFIG_PATH -> doGetConfigServlet(request, response, callback);
case AUTH_PATH -> doAuthEndpoint(request, response, callback);
case TOKEN_PATH -> doTokenEndpoint(request, response, callback);
case END_SESSION_PATH -> doEndSessionEndpoint(request, response, callback);
default -> Response.writeError(request, response, callback, HttpStatus.NOT_FOUND_404);
}
return true;
}
}
protected void doAuthEndpoint(Request request, Response response, Callback callback) throws Exception
{
String method = request.getMethod();
switch (method)
{
case "GET" -> doGetAuthEndpoint(request, response, callback);
case "POST" -> doPostAuthEndpoint(request, response, callback);
default -> throw new BadMessageException("Unsupported HTTP method: " + method);
}
}
protected void doGetAuthEndpoint(Request request, Response response, Callback callback) throws Exception
{
Fields parameters = Request.getParameters(request);
if (!clientId.equals(parameters.getValue("client_id")))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid client_id");
return;
}
String redirectUri = parameters.getValue("redirect_uri");
if (!redirectUris.contains(redirectUri))
{
LOG.warn("invalid redirectUri {}", redirectUri);
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid redirect_uri");
return;
}
String scopeString = parameters.getValue("scope");
List<String> scopes = (scopeString == null) ? Collections.emptyList() : Arrays.asList(scopeString.split(" "));
if (!scopes.contains("openid"))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no openid scope");
return;
}
if (!"code".equals(parameters.getValue("response_type")))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "response_type must be code");
return;
}
String state = parameters.getValue("state");
if (state == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no state param");
return;
}
if (preAuthedUser == null)
{
String responseContent = String.format("""
<h2>Login to OpenID Connect Provider</h2>
<form action="%s" method="post">
<input type="text" autocomplete="off" placeholder="Username" name="username" required>
<input type="hidden" name="redirectUri" value="%s">
<input type="hidden" name="state" value="%s">
<input type="submit">
</form>
""", AUTH_PATH, redirectUri, state);
response.getHeaders().put(HttpHeader.CONTENT_TYPE, "text/html");
response.write(true, BufferUtil.toBuffer(responseContent), callback);
}
else
{
redirectUser(request, response, callback, preAuthedUser, redirectUri, state);
}
}
protected void doPostAuthEndpoint(Request request, Response response, Callback callback) throws Exception
{
Fields parameters = Request.getParameters(request);
String redirectUri = parameters.getValue("redirectUri");
if (!redirectUris.contains(redirectUri))
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid redirect_uri");
return;
}
String state = parameters.getValue("state");
if (state == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no state param");
return;
}
String username = parameters.getValue("username");
if (username == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "no username");
return;
}
User user = new User(username);
redirectUser(request, response, callback, user, redirectUri, state);
}
public void redirectUser(Request request, Response response, Callback callback, User user, String redirectUri, String state) throws IOException
{
String authCode = UUID.randomUUID().toString().replace("-", "");
issuedAuthCodes.put(authCode, user);
try
{
redirectUri += "?code=" + authCode + "&state=" + state;
Response.sendRedirect(request, response, callback, redirectUri);
}
catch (Throwable t)
{
issuedAuthCodes.remove(authCode);
throw t;
}
}
protected void doTokenEndpoint(Request request, Response response, Callback callback) throws Exception
{
if (!HttpMethod.POST.is(request.getMethod()))
throw new BadMessageException("Unsupported HTTP method for token Endpoint: " + request.getMethod());
Fields parameters = Request.getParameters(request);
String code = parameters.getValue("code");
if (!clientId.equals(parameters.getValue("client_id")) ||
!clientSecret.equals(parameters.getValue("client_secret")) ||
!redirectUris.contains(parameters.getValue("redirect_uri")) ||
!"authorization_code".equals(parameters.getValue("grant_type")) ||
code == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "bad auth request");
return;
}
User user = issuedAuthCodes.remove(code);
if (user == null)
{
Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403, "invalid auth code");
return;
}
String accessToken = "ABCDEFG";
long accessTokenDuration = Duration.ofMinutes(10).toSeconds();
String responseContent = "{" +
"\"access_token\": \"" + accessToken + "\"," +
"\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId, _idTokenDuration)) + "\"," +
"\"expires_in\": " + accessTokenDuration + "," +
"\"token_type\": \"Bearer\"" +
"}";
loggedInUsers.increment();
response.getHeaders().put(HttpHeader.CONTENT_TYPE, "text/plain");
response.write(true, BufferUtil.toBuffer(responseContent), callback);
}
protected void doEndSessionEndpoint(Request request, Response response, Callback callback) throws Exception
{
Fields parameters = Request.getParameters(request);
String idToken = parameters.getValue("id_token_hint");
if (idToken == null)
{
Response.writeError(request, response, callback, HttpStatus.BAD_REQUEST_400, "no id_token_hint");
return;
}
String logoutRedirect = parameters.getValue("post_logout_redirect_uri");
if (logoutRedirect == null)
{
response.setStatus(HttpStatus.OK_200);
response.write(true, BufferUtil.toBuffer("logout success on end_session_endpoint"), callback);
return;
}
loggedInUsers.decrement();
Response.sendRedirect(request, response, callback, logoutRedirect);
}
protected void doGetConfigServlet(Request request, Response response, Callback callback) throws IOException
{
String discoveryDocument = "{" +
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";
response.write(true, BufferUtil.toBuffer(discoveryDocument), callback);
}
public static class User
{
private final String subject;
private final String name;
public User(String name)
{
this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name);
}
public User(String subject, String name)
{
this.subject = subject;
this.name = name;
}
public String getName()
{
return name;
}
public String getSubject()
{
return subject;
}
public String getIdToken(String provider, String clientId, long duration)
{
long expiryTime = Instant.now().plusMillis(duration).getEpochSecond();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiryTime);
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
@Override
public int hashCode()
{
return Objects.hash(subject, name);
}
}
}

View File

@ -15,6 +15,7 @@
<module>jetty-jmh</module>
<module>jetty-test-multipart</module>
<module>jetty-test-session-common</module>
<module>jetty-test-util</module>
<module>test-cross-context-dispatch</module>
<module>test-distribution</module>
<module>test-integration</module>

View File

@ -60,6 +60,11 @@
<type>war</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-test-util</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-testers</artifactId>

View File

@ -17,8 +17,8 @@ import java.nio.file.Path;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.client.ContentResponse;
import org.eclipse.jetty.ee10.tests.distribution.openid.OpenIdProvider;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.tests.OpenIdProvider;
import org.eclipse.jetty.tests.distribution.AbstractJettyHomeTest;
import org.eclipse.jetty.tests.testers.JettyHomeTester;
import org.eclipse.jetty.tests.testers.Tester;

View File

@ -1,53 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee10.tests.distribution.openid;
import java.util.Base64;
/**
* A basic JWT encoder for testing purposes.
*/
public class JwtEncoder
{
private static final Base64.Encoder ENCODER = Base64.getUrlEncoder();
private static final String DEFAULT_HEADER = "{\"INFO\": \"this is not used or checked in our implementation\"}";
private static final String DEFAULT_SIGNATURE = "we do not validate signature as we use the authorization code flow";
public static String encode(String idToken)
{
return stripPadding(ENCODER.encodeToString(DEFAULT_HEADER.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(idToken.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(DEFAULT_SIGNATURE.getBytes()));
}
private static String stripPadding(String paddedBase64)
{
return paddedBase64.split("=")[0];
}
/**
* Create a basic JWT for testing using argument supplied attributes.
*/
public static String createIdToken(String provider, String clientId, String subject, String name, long expiry)
{
return "{" +
"\"iss\": \"" + provider + "\"," +
"\"sub\": \"" + subject + "\"," +
"\"aud\": \"" + clientId + "\"," +
"\"exp\": " + expiry + "," +
"\"name\": \"" + name + "\"," +
"\"email\": \"" + name + "@example.com" + "\"" +
"}";
}
}

View File

@ -1,402 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee10.tests.distribution.openid;
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.servlet.ServletHolder;
import org.eclipse.jetty.security.openid.OpenIdConfiguration;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OpenIdProvider extends ContainerLifeCycle
{
private static final Logger LOG = LoggerFactory.getLogger(OpenIdProvider.class);
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();
protected final String clientId;
protected final String clientSecret;
protected final List<String> redirectUris = new ArrayList<>();
private final ServerConnector connector;
private final Server server;
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();
private long _idTokenDuration = Duration.ofSeconds(10).toMillis();
public static void main(String[] args) throws Exception
{
String clientId = "CLIENT_ID123";
String clientSecret = "PASSWORD123";
int port = 5771;
String redirectUri = "http://localhost:8080/j_security_check";
OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret);
openIdProvider.addRedirectUri(redirectUri);
openIdProvider.setPort(port);
openIdProvider.start();
try
{
openIdProvider.join();
}
finally
{
openIdProvider.stop();
}
}
public OpenIdProvider(String clientId, String clientSecret)
{
this.clientId = clientId;
this.clientSecret = clientSecret;
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler();
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new ConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new AuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new TokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new EndSessionEndpoint()), END_SESSION_PATH);
server.setHandler(contextHandler);
addBean(server);
}
public void setIdTokenDuration(long duration)
{
_idTokenDuration = duration;
}
public long getIdTokenDuration()
{
return _idTokenDuration;
}
public void join() throws InterruptedException
{
server.join();
}
public OpenIdConfiguration getOpenIdConfiguration()
{
String provider = getProvider();
String authEndpoint = provider + AUTH_PATH;
String tokenEndpoint = provider + TOKEN_PATH;
return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
}
public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}
@Override
protected void doStart() throws Exception
{
connector.setPort(port);
super.doStart();
provider = "http://localhost:" + connector.getLocalPort();
}
public void setPort(int port)
{
if (isStarted())
throw new IllegalStateException();
this.port = port;
}
public void setUser(User user)
{
this.preAuthedUser = user;
}
public String getProvider()
{
if (!isStarted() && port == 0)
throw new IllegalStateException("Port of OpenIdProvider not configured");
return provider;
}
public void addRedirectUri(String uri)
{
redirectUris.add(uri);
}
public class AuthEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
if (!clientId.equals(req.getParameter("client_id")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid client_id");
return;
}
String redirectUri = req.getParameter("redirect_uri");
if (!redirectUris.contains(redirectUri))
{
LOG.warn("invalid redirectUri {}", redirectUri);
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String scopeString = req.getParameter("scope");
List<String> scopes = (scopeString == null) ? Collections.emptyList() : Arrays.asList(scopeString.split(" "));
if (!scopes.contains("openid"))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no openid scope");
return;
}
if (!"code".equals(req.getParameter("response_type")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "response_type must be code");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
if (preAuthedUser == null)
{
PrintWriter writer = resp.getWriter();
resp.setContentType("text/html");
writer.println("<h2>Login to OpenID Connect Provider</h2>");
writer.println("<form action=\"" + AUTH_PATH + "\" method=\"post\">");
writer.println("<input type=\"text\" autocomplete=\"off\" placeholder=\"Username\" name=\"username\" required>");
writer.println("<input type=\"hidden\" name=\"redirectUri\" value=\"" + redirectUri + "\">");
writer.println("<input type=\"hidden\" name=\"state\" value=\"" + state + "\">");
writer.println("<input type=\"submit\">");
writer.println("</form>");
}
else
{
redirectUser(resp, preAuthedUser, redirectUri, state);
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String redirectUri = req.getParameter("redirectUri");
if (!redirectUris.contains(redirectUri))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
String username = req.getParameter("username");
if (username == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no username");
return;
}
User user = new User(username);
redirectUser(resp, user, redirectUri, state);
}
public void redirectUser(HttpServletResponse response, User user, String redirectUri, String state) throws IOException
{
String authCode = UUID.randomUUID().toString().replace("-", "");
issuedAuthCodes.put(authCode, user);
try
{
redirectUri += "?code=" + authCode + "&state=" + state;
response.sendRedirect(response.encodeRedirectURL(redirectUri));
}
catch (Throwable t)
{
issuedAuthCodes.remove(authCode);
throw t;
}
}
}
private class TokenEndpoint extends HttpServlet
{
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
String code = req.getParameter("code");
if (!clientId.equals(req.getParameter("client_id")) ||
!clientSecret.equals(req.getParameter("client_secret")) ||
!redirectUris.contains(req.getParameter("redirect_uri")) ||
!"authorization_code".equals(req.getParameter("grant_type")) ||
code == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "bad auth request");
return;
}
User user = issuedAuthCodes.remove(code);
if (user == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid auth code");
return;
}
String accessToken = "ABCDEFG";
long accessTokenDuration = Duration.ofMinutes(10).toSeconds();
String response = "{" +
"\"access_token\": \"" + accessToken + "\"," +
"\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId, _idTokenDuration)) + "\"," +
"\"expires_in\": " + accessTokenDuration + "," +
"\"token_type\": \"Bearer\"" +
"}";
loggedInUsers.increment();
resp.setContentType("text/plain");
resp.getWriter().print(response);
}
}
private class EndSessionEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String idToken = req.getParameter("id_token_hint");
if (idToken == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no id_token_hint");
return;
}
String logoutRedirect = req.getParameter("post_logout_redirect_uri");
if (logoutRedirect == null)
{
resp.setStatus(HttpServletResponse.SC_OK);
resp.getWriter().println("logout success on end_session_endpoint");
return;
}
loggedInUsers.decrement();
resp.setContentType("text/plain");
resp.sendRedirect(logoutRedirect);
}
}
private class ConfigServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String discoveryDocument = "{" +
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";
resp.getWriter().write(discoveryDocument);
}
}
public static class User
{
private final String subject;
private final String name;
public User(String name)
{
this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name);
}
public User(String subject, String name)
{
this.subject = subject;
this.name = name;
}
public String getName()
{
return name;
}
public String getSubject()
{
return subject;
}
public String getIdToken(String provider, String clientId, long duration)
{
long expiryTime = Instant.now().plusMillis(duration).getEpochSecond();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiryTime);
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
@Override
public int hashCode()
{
return Objects.hash(subject, name);
}
}
}

View File

@ -55,6 +55,11 @@
<type>war</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-test-util</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-testers</artifactId>

View File

@ -17,8 +17,8 @@ import java.nio.file.Path;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.client.ContentResponse;
import org.eclipse.jetty.ee11.tests.distribution.openid.OpenIdProvider;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.tests.OpenIdProvider;
import org.eclipse.jetty.tests.distribution.AbstractJettyHomeTest;
import org.eclipse.jetty.tests.testers.JettyHomeTester;
import org.eclipse.jetty.tests.testers.Tester;

View File

@ -1,53 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee11.tests.distribution.openid;
import java.util.Base64;
/**
* A basic JWT encoder for testing purposes.
*/
public class JwtEncoder
{
private static final Base64.Encoder ENCODER = Base64.getUrlEncoder();
private static final String DEFAULT_HEADER = "{\"INFO\": \"this is not used or checked in our implementation\"}";
private static final String DEFAULT_SIGNATURE = "we do not validate signature as we use the authorization code flow";
public static String encode(String idToken)
{
return stripPadding(ENCODER.encodeToString(DEFAULT_HEADER.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(idToken.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(DEFAULT_SIGNATURE.getBytes()));
}
private static String stripPadding(String paddedBase64)
{
return paddedBase64.split("=")[0];
}
/**
* Create a basic JWT for testing using argument supplied attributes.
*/
public static String createIdToken(String provider, String clientId, String subject, String name, long expiry)
{
return "{" +
"\"iss\": \"" + provider + "\"," +
"\"sub\": \"" + subject + "\"," +
"\"aud\": \"" + clientId + "\"," +
"\"exp\": " + expiry + "," +
"\"name\": \"" + name + "\"," +
"\"email\": \"" + name + "@example.com" + "\"" +
"}";
}
}

View File

@ -1,402 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee11.tests.distribution.openid;
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.ee11.servlet.ServletContextHandler;
import org.eclipse.jetty.ee11.servlet.ServletHolder;
import org.eclipse.jetty.security.openid.OpenIdConfiguration;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OpenIdProvider extends ContainerLifeCycle
{
private static final Logger LOG = LoggerFactory.getLogger(OpenIdProvider.class);
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();
protected final String clientId;
protected final String clientSecret;
protected final List<String> redirectUris = new ArrayList<>();
private final ServerConnector connector;
private final Server server;
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();
private long _idTokenDuration = Duration.ofSeconds(10).toMillis();
public static void main(String[] args) throws Exception
{
String clientId = "CLIENT_ID123";
String clientSecret = "PASSWORD123";
int port = 5771;
String redirectUri = "http://localhost:8080/j_security_check";
OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret);
openIdProvider.addRedirectUri(redirectUri);
openIdProvider.setPort(port);
openIdProvider.start();
try
{
openIdProvider.join();
}
finally
{
openIdProvider.stop();
}
}
public OpenIdProvider(String clientId, String clientSecret)
{
this.clientId = clientId;
this.clientSecret = clientSecret;
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler();
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new ConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new AuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new TokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new EndSessionEndpoint()), END_SESSION_PATH);
server.setHandler(contextHandler);
addBean(server);
}
public void setIdTokenDuration(long duration)
{
_idTokenDuration = duration;
}
public long getIdTokenDuration()
{
return _idTokenDuration;
}
public void join() throws InterruptedException
{
server.join();
}
public OpenIdConfiguration getOpenIdConfiguration()
{
String provider = getProvider();
String authEndpoint = provider + AUTH_PATH;
String tokenEndpoint = provider + TOKEN_PATH;
return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
}
public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}
@Override
protected void doStart() throws Exception
{
connector.setPort(port);
super.doStart();
provider = "http://localhost:" + connector.getLocalPort();
}
public void setPort(int port)
{
if (isStarted())
throw new IllegalStateException();
this.port = port;
}
public void setUser(User user)
{
this.preAuthedUser = user;
}
public String getProvider()
{
if (!isStarted() && port == 0)
throw new IllegalStateException("Port of OpenIdProvider not configured");
return provider;
}
public void addRedirectUri(String uri)
{
redirectUris.add(uri);
}
public class AuthEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
if (!clientId.equals(req.getParameter("client_id")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid client_id");
return;
}
String redirectUri = req.getParameter("redirect_uri");
if (!redirectUris.contains(redirectUri))
{
LOG.warn("invalid redirectUri {}", redirectUri);
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String scopeString = req.getParameter("scope");
List<String> scopes = (scopeString == null) ? Collections.emptyList() : Arrays.asList(scopeString.split(" "));
if (!scopes.contains("openid"))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no openid scope");
return;
}
if (!"code".equals(req.getParameter("response_type")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "response_type must be code");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
if (preAuthedUser == null)
{
PrintWriter writer = resp.getWriter();
resp.setContentType("text/html");
writer.println("<h2>Login to OpenID Connect Provider</h2>");
writer.println("<form action=\"" + AUTH_PATH + "\" method=\"post\">");
writer.println("<input type=\"text\" autocomplete=\"off\" placeholder=\"Username\" name=\"username\" required>");
writer.println("<input type=\"hidden\" name=\"redirectUri\" value=\"" + redirectUri + "\">");
writer.println("<input type=\"hidden\" name=\"state\" value=\"" + state + "\">");
writer.println("<input type=\"submit\">");
writer.println("</form>");
}
else
{
redirectUser(resp, preAuthedUser, redirectUri, state);
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String redirectUri = req.getParameter("redirectUri");
if (!redirectUris.contains(redirectUri))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
String username = req.getParameter("username");
if (username == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no username");
return;
}
User user = new User(username);
redirectUser(resp, user, redirectUri, state);
}
public void redirectUser(HttpServletResponse response, User user, String redirectUri, String state) throws IOException
{
String authCode = UUID.randomUUID().toString().replace("-", "");
issuedAuthCodes.put(authCode, user);
try
{
redirectUri += "?code=" + authCode + "&state=" + state;
response.sendRedirect(response.encodeRedirectURL(redirectUri));
}
catch (Throwable t)
{
issuedAuthCodes.remove(authCode);
throw t;
}
}
}
private class TokenEndpoint extends HttpServlet
{
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
String code = req.getParameter("code");
if (!clientId.equals(req.getParameter("client_id")) ||
!clientSecret.equals(req.getParameter("client_secret")) ||
!redirectUris.contains(req.getParameter("redirect_uri")) ||
!"authorization_code".equals(req.getParameter("grant_type")) ||
code == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "bad auth request");
return;
}
User user = issuedAuthCodes.remove(code);
if (user == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid auth code");
return;
}
String accessToken = "ABCDEFG";
long accessTokenDuration = Duration.ofMinutes(10).toSeconds();
String response = "{" +
"\"access_token\": \"" + accessToken + "\"," +
"\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId, _idTokenDuration)) + "\"," +
"\"expires_in\": " + accessTokenDuration + "," +
"\"token_type\": \"Bearer\"" +
"}";
loggedInUsers.increment();
resp.setContentType("text/plain");
resp.getWriter().print(response);
}
}
private class EndSessionEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String idToken = req.getParameter("id_token_hint");
if (idToken == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no id_token_hint");
return;
}
String logoutRedirect = req.getParameter("post_logout_redirect_uri");
if (logoutRedirect == null)
{
resp.setStatus(HttpServletResponse.SC_OK);
resp.getWriter().println("logout success on end_session_endpoint");
return;
}
loggedInUsers.decrement();
resp.setContentType("text/plain");
resp.sendRedirect(logoutRedirect);
}
}
private class ConfigServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String discoveryDocument = "{" +
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";
resp.getWriter().write(discoveryDocument);
}
}
public static class User
{
private final String subject;
private final String name;
public User(String name)
{
this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name);
}
public User(String subject, String name)
{
this.subject = subject;
this.name = name;
}
public String getName()
{
return name;
}
public String getSubject()
{
return subject;
}
public String getIdToken(String provider, String clientId, long duration)
{
long expiryTime = Instant.now().plusMillis(duration).getEpochSecond();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiryTime);
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
@Override
public int hashCode()
{
return Objects.hash(subject, name);
}
}
}

View File

@ -56,6 +56,11 @@
<type>war</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-test-util</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.tests</groupId>
<artifactId>jetty-testers</artifactId>

View File

@ -17,8 +17,8 @@ import java.nio.file.Path;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.client.ContentResponse;
import org.eclipse.jetty.ee9.tests.distribution.openid.OpenIdProvider;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.tests.OpenIdProvider;
import org.eclipse.jetty.tests.distribution.AbstractJettyHomeTest;
import org.eclipse.jetty.tests.testers.JettyHomeTester;
import org.eclipse.jetty.tests.testers.Tester;

View File

@ -1,53 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee9.tests.distribution.openid;
import java.util.Base64;
/**
* A basic JWT encoder for testing purposes.
*/
public class JwtEncoder
{
private static final Base64.Encoder ENCODER = Base64.getUrlEncoder();
private static final String DEFAULT_HEADER = "{\"INFO\": \"this is not used or checked in our implementation\"}";
private static final String DEFAULT_SIGNATURE = "we do not validate signature as we use the authorization code flow";
public static String encode(String idToken)
{
return stripPadding(ENCODER.encodeToString(DEFAULT_HEADER.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(idToken.getBytes())) + "." +
stripPadding(ENCODER.encodeToString(DEFAULT_SIGNATURE.getBytes()));
}
private static String stripPadding(String paddedBase64)
{
return paddedBase64.split("=")[0];
}
/**
* Create a basic JWT for testing using argument supplied attributes.
*/
public static String createIdToken(String provider, String clientId, String subject, String name, long expiry)
{
return "{" +
"\"iss\": \"" + provider + "\"," +
"\"sub\": \"" + subject + "\"," +
"\"aud\": \"" + clientId + "\"," +
"\"exp\": " + expiry + "," +
"\"name\": \"" + name + "\"," +
"\"email\": \"" + name + "@example.com" + "\"" +
"}";
}
}

View File

@ -1,402 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.ee9.tests.distribution.openid;
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.ee9.servlet.ServletContextHandler;
import org.eclipse.jetty.ee9.servlet.ServletHolder;
import org.eclipse.jetty.security.openid.OpenIdConfiguration;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.statistic.CounterStatistic;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OpenIdProvider extends ContainerLifeCycle
{
private static final Logger LOG = LoggerFactory.getLogger(OpenIdProvider.class);
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
private static final String END_SESSION_PATH = "/end_session";
private final Map<String, User> issuedAuthCodes = new HashMap<>();
protected final String clientId;
protected final String clientSecret;
protected final List<String> redirectUris = new ArrayList<>();
private final ServerConnector connector;
private final Server server;
private int port = 0;
private String provider;
private User preAuthedUser;
private final CounterStatistic loggedInUsers = new CounterStatistic();
private long _idTokenDuration = Duration.ofSeconds(10).toMillis();
public static void main(String[] args) throws Exception
{
String clientId = "CLIENT_ID123";
String clientSecret = "PASSWORD123";
int port = 5771;
String redirectUri = "http://localhost:8080/j_security_check";
OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret);
openIdProvider.addRedirectUri(redirectUri);
openIdProvider.setPort(port);
openIdProvider.start();
try
{
openIdProvider.join();
}
finally
{
openIdProvider.stop();
}
}
public OpenIdProvider(String clientId, String clientSecret)
{
this.clientId = clientId;
this.clientSecret = clientSecret;
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler();
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new ConfigServlet()), CONFIG_PATH);
contextHandler.addServlet(new ServletHolder(new AuthEndpoint()), AUTH_PATH);
contextHandler.addServlet(new ServletHolder(new TokenEndpoint()), TOKEN_PATH);
contextHandler.addServlet(new ServletHolder(new EndSessionEndpoint()), END_SESSION_PATH);
server.setHandler(contextHandler);
addBean(server);
}
public void setIdTokenDuration(long duration)
{
_idTokenDuration = duration;
}
public long getIdTokenDuration()
{
return _idTokenDuration;
}
public void join() throws InterruptedException
{
server.join();
}
public OpenIdConfiguration getOpenIdConfiguration()
{
String provider = getProvider();
String authEndpoint = provider + AUTH_PATH;
String tokenEndpoint = provider + TOKEN_PATH;
return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
}
public CounterStatistic getLoggedInUsers()
{
return loggedInUsers;
}
@Override
protected void doStart() throws Exception
{
connector.setPort(port);
super.doStart();
provider = "http://localhost:" + connector.getLocalPort();
}
public void setPort(int port)
{
if (isStarted())
throw new IllegalStateException();
this.port = port;
}
public void setUser(User user)
{
this.preAuthedUser = user;
}
public String getProvider()
{
if (!isStarted() && port == 0)
throw new IllegalStateException("Port of OpenIdProvider not configured");
return provider;
}
public void addRedirectUri(String uri)
{
redirectUris.add(uri);
}
public class AuthEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
if (!clientId.equals(req.getParameter("client_id")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid client_id");
return;
}
String redirectUri = req.getParameter("redirect_uri");
if (!redirectUris.contains(redirectUri))
{
LOG.warn("invalid redirectUri {}", redirectUri);
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String scopeString = req.getParameter("scope");
List<String> scopes = (scopeString == null) ? Collections.emptyList() : Arrays.asList(scopeString.split(" "));
if (!scopes.contains("openid"))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no openid scope");
return;
}
if (!"code".equals(req.getParameter("response_type")))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "response_type must be code");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
if (preAuthedUser == null)
{
PrintWriter writer = resp.getWriter();
resp.setContentType("text/html");
writer.println("<h2>Login to OpenID Connect Provider</h2>");
writer.println("<form action=\"" + AUTH_PATH + "\" method=\"post\">");
writer.println("<input type=\"text\" autocomplete=\"off\" placeholder=\"Username\" name=\"username\" required>");
writer.println("<input type=\"hidden\" name=\"redirectUri\" value=\"" + redirectUri + "\">");
writer.println("<input type=\"hidden\" name=\"state\" value=\"" + state + "\">");
writer.println("<input type=\"submit\">");
writer.println("</form>");
}
else
{
redirectUser(resp, preAuthedUser, redirectUri, state);
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String redirectUri = req.getParameter("redirectUri");
if (!redirectUris.contains(redirectUri))
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
String state = req.getParameter("state");
if (state == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
return;
}
String username = req.getParameter("username");
if (username == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no username");
return;
}
User user = new User(username);
redirectUser(resp, user, redirectUri, state);
}
public void redirectUser(HttpServletResponse response, User user, String redirectUri, String state) throws IOException
{
String authCode = UUID.randomUUID().toString().replace("-", "");
issuedAuthCodes.put(authCode, user);
try
{
redirectUri += "?code=" + authCode + "&state=" + state;
response.sendRedirect(response.encodeRedirectURL(redirectUri));
}
catch (Throwable t)
{
issuedAuthCodes.remove(authCode);
throw t;
}
}
}
private class TokenEndpoint extends HttpServlet
{
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
String code = req.getParameter("code");
if (!clientId.equals(req.getParameter("client_id")) ||
!clientSecret.equals(req.getParameter("client_secret")) ||
!redirectUris.contains(req.getParameter("redirect_uri")) ||
!"authorization_code".equals(req.getParameter("grant_type")) ||
code == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "bad auth request");
return;
}
User user = issuedAuthCodes.remove(code);
if (user == null)
{
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid auth code");
return;
}
String accessToken = "ABCDEFG";
long accessTokenDuration = Duration.ofMinutes(10).toSeconds();
String response = "{" +
"\"access_token\": \"" + accessToken + "\"," +
"\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId, _idTokenDuration)) + "\"," +
"\"expires_in\": " + accessTokenDuration + "," +
"\"token_type\": \"Bearer\"" +
"}";
loggedInUsers.increment();
resp.setContentType("text/plain");
resp.getWriter().print(response);
}
}
private class EndSessionEndpoint extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String idToken = req.getParameter("id_token_hint");
if (idToken == null)
{
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "no id_token_hint");
return;
}
String logoutRedirect = req.getParameter("post_logout_redirect_uri");
if (logoutRedirect == null)
{
resp.setStatus(HttpServletResponse.SC_OK);
resp.getWriter().println("logout success on end_session_endpoint");
return;
}
loggedInUsers.decrement();
resp.setContentType("text/plain");
resp.sendRedirect(logoutRedirect);
}
}
private class ConfigServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String discoveryDocument = "{" +
"\"issuer\": \"" + provider + "\"," +
"\"authorization_endpoint\": \"" + provider + AUTH_PATH + "\"," +
"\"token_endpoint\": \"" + provider + TOKEN_PATH + "\"," +
"\"end_session_endpoint\": \"" + provider + END_SESSION_PATH + "\"," +
"}";
resp.getWriter().write(discoveryDocument);
}
}
public static class User
{
private final String subject;
private final String name;
public User(String name)
{
this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name);
}
public User(String subject, String name)
{
this.subject = subject;
this.name = name;
}
public String getName()
{
return name;
}
public String getSubject()
{
return subject;
}
public String getIdToken(String provider, String clientId, long duration)
{
long expiryTime = Instant.now().plusMillis(duration).getEpochSecond();
return JwtEncoder.createIdToken(provider, clientId, subject, name, expiryTime);
}
@Override
public boolean equals(Object obj)
{
if (!(obj instanceof User))
return false;
return Objects.equals(subject, ((User)obj).subject) && Objects.equals(name, ((User)obj).name);
}
@Override
public int hashCode()
{
return Objects.hash(subject, name);
}
}
}