add refresh token feature.

This commit is contained in:
eelhazati 2019-07-23 09:09:56 +01:00
parent dd0003a478
commit b6de1db857
8 changed files with 250 additions and 52 deletions

View File

@ -15,15 +15,15 @@ import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@Path("token") @Path("token")
public class TokenEndpoint { public class TokenEndpoint {
List<String> supportedGrantTypes = Collections.singletonList("authorization_code"); List<String> supportedGrantTypes = Arrays.asList("authorization_code", "refresh_token");
@Inject @Inject
private AppDataRepository appDataRepository; private AppDataRepository appDataRepository;

View File

@ -0,0 +1,87 @@
package com.baeldung.oauth2.authorization.server.handler;
import com.baeldung.oauth2.authorization.server.PEMKeyUtils;
import com.nimbusds.jose.*;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.eclipse.microprofile.config.Config;
import javax.inject.Inject;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Date;
import java.util.UUID;
public abstract class AbstractGrantTypeHandler implements AuthorizationGrantTypeHandler {
//Always RSA 256, but could be parametrized
protected JWSHeader jwsHeader = new JWSHeader.Builder(JWSAlgorithm.RS256).type(JOSEObjectType.JWT).build();
@Inject
protected Config config;
//30 min
protected Long expiresInMin = 30L;
protected JWSVerifier getJWSVerifier() throws Exception {
String verificationkey = config.getValue("verificationkey", String.class);
String pemEncodedRSAPublicKey = PEMKeyUtils.readKeyAsString(verificationkey);
RSAKey rsaPublicKey = (RSAKey) JWK.parseFromPEMEncodedObjects(pemEncodedRSAPublicKey);
return new RSASSAVerifier(rsaPublicKey);
}
protected JWSSigner getJwsSigner() throws Exception {
String signingkey = config.getValue("signingkey", String.class);
String pemEncodedRSAPrivateKey = PEMKeyUtils.readKeyAsString(signingkey);
RSAKey rsaKey = (RSAKey) JWK.parseFromPEMEncodedObjects(pemEncodedRSAPrivateKey);
return new RSASSASigner(rsaKey.toRSAPrivateKey());
}
protected String getAccessToken(String clientId, String subject, String approvedScope) throws Exception {
//4. Signing
JWSSigner jwsSigner = getJwsSigner();
Instant now = Instant.now();
//Long expiresInMin = 30L;
Date expirationTime = Date.from(now.plus(expiresInMin, ChronoUnit.MINUTES));
//3. JWT Payload or claims
JWTClaimsSet jwtClaims = new JWTClaimsSet.Builder()
.issuer("http://localhost:9080")
.subject(subject)
.claim("upn", subject)
.claim("client_id", clientId)
.audience("http://localhost:9280")
.claim("scope", approvedScope)
.claim("groups", Arrays.asList(approvedScope.split(" ")))
.expirationTime(expirationTime) // expires in 30 minutes
.notBeforeTime(Date.from(now))
.issueTime(Date.from(now))
.jwtID(UUID.randomUUID().toString())
.build();
SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaims);
signedJWT.sign(jwsSigner);
return signedJWT.serialize();
}
protected String getRefreshToken(String clientId, String subject, String approvedScope) throws Exception {
JWSSigner jwsSigner = getJwsSigner();
Instant now = Instant.now();
//6.Build refresh token
JWTClaimsSet refreshTokenClaims = new JWTClaimsSet.Builder()
.subject(subject)
.claim("client_id", clientId)
.claim("scope", approvedScope)
//refresh token for 1 day.
.expirationTime(Date.from(now.plus(1, ChronoUnit.DAYS)))
.build();
SignedJWT signedRefreshToken = new SignedJWT(jwsHeader, refreshTokenClaims);
signedRefreshToken.sign(jwsSigner);
return signedRefreshToken.serialize();
}
}

View File

@ -1,18 +1,7 @@
package com.baeldung.oauth2.authorization.server.handler; package com.baeldung.oauth2.authorization.server.handler;
import com.baeldung.oauth2.authorization.server.PEMKeyUtils;
import com.baeldung.oauth2.authorization.server.model.AuthorizationCode; import com.baeldung.oauth2.authorization.server.model.AuthorizationCode;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.eclipse.microprofile.config.Config;
import javax.inject.Inject;
import javax.inject.Named; import javax.inject.Named;
import javax.json.Json; import javax.json.Json;
import javax.json.JsonObject; import javax.json.JsonObject;
@ -20,22 +9,14 @@ import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext; import javax.persistence.PersistenceContext;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import java.time.Instant;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Date;
import java.util.UUID;
@Named("authorization_code") @Named("authorization_code")
public class AuthorizationCodeGrantTypeHandler implements AuthorizationGrantTypeHandler { public class AuthorizationCodeGrantTypeHandler extends AbstractGrantTypeHandler {
@PersistenceContext @PersistenceContext
private EntityManager entityManager; private EntityManager entityManager;
@Inject
private Config config;
@Override @Override
public JsonObject createAccessToken(String clientId, MultivaluedMap<String, String> params) throws Exception { public JsonObject createAccessToken(String clientId, MultivaluedMap<String, String> params) throws Exception {
//1. code is required //1. code is required
@ -58,42 +39,16 @@ public class AuthorizationCodeGrantTypeHandler implements AuthorizationGrantType
throw new WebApplicationException("invalid_grant"); throw new WebApplicationException("invalid_grant");
} }
//JWT Header
JWSHeader jwsHeader = new JWSHeader.Builder(JWSAlgorithm.RS256).type(JOSEObjectType.JWT).build();
Instant now = Instant.now();
Long expiresInMin = 30L;
Date expiresIn = Date.from(now.plus(expiresInMin, ChronoUnit.MINUTES));
//3. JWT Payload or claims //3. JWT Payload or claims
JWTClaimsSet jwtClaims = new JWTClaimsSet.Builder() String accessToken = getAccessToken(clientId, authorizationCode.getUserId(), authorizationCode.getApprovedScopes());
.issuer("http://localhost:9080") String refreshToken = getRefreshToken(clientId, authorizationCode.getUserId(), authorizationCode.getApprovedScopes());
.subject(authorizationCode.getUserId())
.claim("upn", authorizationCode.getUserId())
.audience("http://localhost:9280")
.claim("scope", authorizationCode.getApprovedScopes())
.claim("groups", Arrays.asList(authorizationCode.getApprovedScopes().split(" ")))
.expirationTime(expiresIn) // expires in 30 minutes
.notBeforeTime(Date.from(now))
.issueTime(Date.from(now))
.jwtID(UUID.randomUUID().toString())
.build();
SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaims);
//4. Signing
String signingkey = config.getValue("signingkey", String.class);
String pemEncodedRSAPrivateKey = PEMKeyUtils.readKeyAsString(signingkey);
RSAKey rsaKey = (RSAKey) JWK.parseFromPEMEncodedObjects(pemEncodedRSAPrivateKey);
signedJWT.sign(new RSASSASigner(rsaKey.toRSAPrivateKey()));
//5. Finally the JWT access token
String accessToken = signedJWT.serialize();
return Json.createObjectBuilder() return Json.createObjectBuilder()
.add("token_type", "Bearer") .add("token_type", "Bearer")
.add("access_token", accessToken) .add("access_token", accessToken)
.add("expires_in", expiresInMin * 60) .add("expires_in", expiresInMin * 60)
.add("scope", authorizationCode.getApprovedScopes()) .add("scope", authorizationCode.getApprovedScopes())
.add("refresh_token", refreshToken)
.build(); .build();
} }
} }

View File

@ -0,0 +1,64 @@
package com.baeldung.oauth2.authorization.server.handler;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jwt.SignedJWT;
import javax.inject.Named;
import javax.json.Json;
import javax.json.JsonObject;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MultivaluedMap;
import java.util.*;
@Named("refresh_token")
public class RefreshTokenGrantTypeHandler extends AbstractGrantTypeHandler {
@Override
public JsonObject createAccessToken(String clientId, MultivaluedMap<String, String> params) throws Exception {
String refreshToken = params.getFirst("refresh_token");
if (refreshToken == null || "".equals(refreshToken)) {
throw new WebApplicationException("invalid_grant");
}
//Decode refresh token
SignedJWT signedRefreshToken = SignedJWT.parse(refreshToken);
JWSVerifier verifier = getJWSVerifier();
if (!signedRefreshToken.verify(verifier)) {
throw new WebApplicationException("Invalid refresh token.");
}
if (!(new Date().before(signedRefreshToken.getJWTClaimsSet().getExpirationTime()))) {
throw new WebApplicationException("Refresh token expired.");
}
String refreshTokenClientId = signedRefreshToken.getJWTClaimsSet().getStringClaim("client_id");
if (!clientId.equals(refreshTokenClientId)) {
throw new WebApplicationException("Invalid client_id.");
}
//At this point, the refresh token is valid and not yet expired
//So create a new access token from it.
String subject = signedRefreshToken.getJWTClaimsSet().getSubject();
String approvedScopes = signedRefreshToken.getJWTClaimsSet().getStringClaim("scope");
String finalScope = approvedScopes;
String requestedScopes = params.getFirst("scope");
if (requestedScopes != null && !requestedScopes.isEmpty()) {
Set<String> allowedScopes = new LinkedHashSet<>();
Set<String> rScopes = new HashSet(Arrays.asList(requestedScopes.split(" ")));
Set<String> aScopes = new HashSet(Arrays.asList(approvedScopes.split(" ")));
for (String scope : rScopes) {
if (aScopes.contains(scope)) allowedScopes.add(scope);
}
finalScope = String.join(" ", allowedScopes);
}
String accessToken = getAccessToken(clientId, subject, finalScope);
return Json.createObjectBuilder()
.add("token_type", "Bearer")
.add("access_token", accessToken)
.add("expires_in", expiresInMin * 60)
.add("scope", finalScope)
.add("refresh_token", refreshToken)
.build();
}
}

View File

@ -0,0 +1,23 @@
package com.baeldung.oauth2.client;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Base64;
public abstract class AbstractServlet extends HttpServlet {
protected void dispatch(String location, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
RequestDispatcher requestDispatcher = request.getRequestDispatcher(location);
requestDispatcher.forward(request, response);
}
protected String getAuthorizationHeaderValue(String clientId, String clientSecret) {
String token = clientId + ":" + clientSecret;
String encodedString = Base64.getEncoder().encodeToString(token.getBytes());
return "Basic " + encodedString;
}
}

View File

@ -29,6 +29,9 @@ public class CallbackServlet extends HttpServlet {
@Override @Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
String clientId = config.getValue("client.clientId", String.class);
String clientSecret = config.getValue("client.clientSecret", String.class);
//Error: //Error:
String error = request.getParameter("error"); String error = request.getParameter("error");
if (error != null) { if (error != null) {

View File

@ -0,0 +1,52 @@
package com.baeldung.oauth2.client;
import org.eclipse.microprofile.config.Config;
import javax.inject.Inject;
import javax.json.JsonObject;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Form;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
@WebServlet(urlPatterns = "/refreshtoken")
public class RefreshTokenServlet extends AbstractServlet {
@Inject
private Config config;
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
String clientId = config.getValue("client.clientId", String.class);
String clientSecret = config.getValue("client.clientSecret", String.class);
JsonObject actualTokenResponse = (JsonObject) request.getSession().getAttribute("tokenResponse");
Client client = ClientBuilder.newClient();
WebTarget target = client.target(config.getValue("provider.tokenUri", String.class));
Form form = new Form();
form.param("grant_type", "refresh_token");
form.param("refresh_token", actualTokenResponse.getString("refresh_token"));
String scope = request.getParameter("scope");
if (scope != null && !scope.isEmpty()) {
form.param("scope", scope);
}
JsonObject tokenResponse = target.request(MediaType.APPLICATION_JSON_TYPE)
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeaderValue(clientId, clientSecret))
.post(Entity.entity(form, MediaType.APPLICATION_FORM_URLENCODED_TYPE), JsonObject.class);
request.getSession().setAttribute("tokenResponse", tokenResponse);
dispatch("/", request, response);
}
}

View File

@ -10,6 +10,7 @@
body { body {
margin: 0px; margin: 0px;
} }
input[type=text], input[type=password] { input[type=text], input[type=password] {
width: 75%; width: 75%;
padding: 4px 0px; padding: 4px 0px;
@ -17,6 +18,7 @@
border: 1px solid #502bcc; border: 1px solid #502bcc;
box-sizing: border-box; box-sizing: border-box;
} }
.container-error { .container-error {
padding: 16px; padding: 16px;
border: 1px solid #cc102a; border: 1px solid #cc102a;
@ -25,6 +27,7 @@
margin-left: 25px; margin-left: 25px;
margin-bottom: 25px; margin-bottom: 25px;
} }
.container { .container {
padding: 16px; padding: 16px;
border: 1px solid #130ecc; border: 1px solid #130ecc;
@ -81,8 +84,20 @@
<li>access_token: ${tokenResponse.getString("access_token")}</li> <li>access_token: ${tokenResponse.getString("access_token")}</li>
<li>scope: ${tokenResponse.getString("scope")}</li> <li>scope: ${tokenResponse.getString("scope")}</li>
<li>Expires in (s): ${tokenResponse.getInt("expires_in")}</li> <li>Expires in (s): ${tokenResponse.getInt("expires_in")}</li>
<li>refresh_token: ${tokenResponse.getString("refresh_token")}</li>
</ul> </ul>
</div> </div>
<div class="container">
<span><h4>Refresh Token</h4></span>
<hr>
<ul>
<li><a href="refreshtoken">Refresh token (original scope)</a></li>
<li><a href="refreshtoken?scope=resource.read">Refresh token (scope: resource.read)</a></li>
<li><a href="refreshtoken?scope=resource.write">Refresh token (scope: resource.write)</a></li>
</ul>
</div>
<div class="container"> <div class="container">
<span><h4>OAuth2 Resource Server Call</h4></span> <span><h4>OAuth2 Resource Server Call</h4></span>
<hr> <hr>
@ -90,7 +105,6 @@
<li><a href="downstream?action=read">Read Protected Resource</a></li> <li><a href="downstream?action=read">Read Protected Resource</a></li>
<li><a href="downstream?action=write">Write Protected Resource</a></li> <li><a href="downstream?action=write">Write Protected Resource</a></li>
</ul> </ul>
</div> </div>
</body> </body>