Thread-safe ServiceLoader usage

Blend of pre-0.11.0 behavior that cached implementation instances and post-0.11.0 behavior using the JDK ServiceLoader to find/create instances of an SPI interface.  This change:

- Reinstates the <= 0.10.x behavior of caching application singleton service implementation instances in a thread-safe reference (previously an AtomicReference, but in this change, a ConcurrentMap).  If an app singleton instance is cached and found, it is returned to be (re)used immediately when requested.  This is ok for JJWT's purposes because all service implementations instances must be thread-safe application singletons by API contract/design, so caching them for repeated use is fine.

- Ensures that only if a service implementation instance is not in the app singleton cache, a new instance is located/created using a new JDK ServiceLoader instance, which doesn't require thread-safe considerations since it is used only in a single-threaded model for the short time it is used to discover a service implementation.  This PR/change removes the post-0.11.0 concurrent cache of ServiceLoader instances since they themselves are not designed to be thread-safe.

- Ensures that if a ServiceLoader discovers an implementation and returns a new instance, that instance is then cached as an application singleton in the aforementioned ConcurrentMap for continued reuse.

- Renames Services#loadFirst to Services#get to more accurately reflect calling expectations:  The fact that any 'loading' via the ServiceLoader may occur is not important for Services callers, and the previous method name was unnecessarily exposing internal implementation concepts.  This is safe to do in a point release (0.12.3 -> 0.12.4) because the Services class and its methods, while public, are in the `impl` module, only to be used internally for JJWT's purpose and never intended to be used by application developers.

- Updates all test methods to use the renamed method accordingly.

Fixes #873
This commit is contained in:
lhazlewood 2024-01-17 13:35:20 -08:00 committed by GitHub
parent 406f2f39df
commit d878404434
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 59 additions and 102 deletions

View File

@ -508,7 +508,7 @@ public class DefaultJwtBuilder implements JwtBuilder {
if (this.serializer == null) { // try to find one based on the services available
//noinspection unchecked
json(Services.loadFirst(Serializer.class));
json(Services.get(Serializer.class));
}
if (!Collections.isEmpty(claims)) { // normalize so we have one object to deal with:

View File

@ -370,7 +370,7 @@ public class DefaultJwtParserBuilder implements JwtParserBuilder {
if (this.deserializer == null) {
//noinspection unchecked
json(Services.loadFirst(Deserializer.class));
json(Services.get(Deserializer.class));
}
if (this.signingKeyResolver != null && this.signatureVerificationKey != null) {
String msg = "Both a 'signingKeyResolver and a 'verifyWith' key cannot be configured. " +

View File

@ -50,7 +50,7 @@ public abstract class AbstractParserBuilder<T, B extends ParserBuilder<T, B>> im
public final Parser<T> build() {
if (this.deserializer == null) {
//noinspection unchecked
this.deserializer = Services.loadFirst(Deserializer.class);
this.deserializer = Services.get(Deserializer.class);
}
return doBuild();
}

View File

@ -15,25 +15,24 @@
*/
package io.jsonwebtoken.impl.lang;
import io.jsonwebtoken.lang.Arrays;
import io.jsonwebtoken.lang.Assert;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ServiceLoader;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import static io.jsonwebtoken.lang.Collections.arrayToList;
/**
* Helper class for loading services from the classpath, using a {@link ServiceLoader}. Decouples loading logic for
* better separation of concerns and testability.
*/
public final class Services {
private static ConcurrentMap<Class<?>, ServiceLoader<?>> SERVICE_CACHE = new ConcurrentHashMap<>();
private static final ConcurrentMap<Class<?>, Object> SERVICES = new ConcurrentHashMap<>();
private static final List<ClassLoaderAccessor> CLASS_LOADER_ACCESSORS = arrayToList(new ClassLoaderAccessor[] {
private static final List<ClassLoaderAccessor> CLASS_LOADER_ACCESSORS = Arrays.asList(new ClassLoaderAccessor[]{
new ClassLoaderAccessor() {
@Override
public ClassLoader getClassLoader() {
@ -54,86 +53,57 @@ public final class Services {
}
});
private Services() {}
/**
* Loads and instantiates all service implementation of the given SPI class and returns them as a List.
*
* @param spi The class of the Service Provider Interface
* @param <T> The type of the SPI
* @return An unmodifiable list with an instance of all available implementations of the SPI. No guarantee is given
* on the order of implementations, if more than one.
*/
public static <T> List<T> loadAll(Class<T> spi) {
Assert.notNull(spi, "Parameter 'spi' must not be null.");
ServiceLoader<T> serviceLoader = serviceLoader(spi);
if (serviceLoader != null) {
List<T> implementations = new ArrayList<>();
for (T implementation : serviceLoader) {
implementations.add(implementation);
}
return implementations;
}
throw new UnavailableImplementationException(spi);
private Services() {
}
/**
* Loads the first available implementation the given SPI class from the classpath. Uses the {@link ServiceLoader}
* to find implementations. When multiple implementations are available it will return the first one that it
* encounters. There is no guarantee with regard to ordering.
* Returns the first available implementation for the given SPI class, checking an internal thread-safe cache first,
* and, if not found, using a {@link ServiceLoader} to find implementations. When multiple implementations are
* available it will return the first one that it encounters. There is no guarantee with regard to ordering.
*
* @param spi The class of the Service Provider Interface
* @param <T> The type of the SPI
* @return A new instance of the service.
* @throws UnavailableImplementationException When no implementation the SPI is available on the classpath.
* @return The first available instance of the service.
* @throws UnavailableImplementationException When no implementation of the SPI class can be found.
*/
public static <T> T loadFirst(Class<T> spi) {
Assert.notNull(spi, "Parameter 'spi' must not be null.");
ServiceLoader<T> serviceLoader = serviceLoader(spi);
if (serviceLoader != null) {
return serviceLoader.iterator().next();
public static <T> T get(Class<T> spi) {
// TODO: JDK8, replace this find/putIfAbsent logic with ConcurrentMap.computeIfAbsent
T instance = findCached(spi);
if (instance == null) {
instance = loadFirst(spi); // throws UnavailableImplementationException if not found, which is what we want
SERVICES.putIfAbsent(spi, instance); // cache if not already cached
}
throw new UnavailableImplementationException(spi);
return instance;
}
/**
* Returns a ServiceLoader for <code>spi</code> class, checking multiple classloaders. The ServiceLoader
* will be cached if it contains at least one implementation of the <code>spi</code> class.<BR>
*
* <b>NOTE:</b> Only the first Serviceloader will be cached.
* @param spi The interface or abstract class representing the service loader.
* @return A service loader, or null if no implementations are found
* @param <T> The type of the SPI.
*/
private static <T> ServiceLoader<T> serviceLoader(Class<T> spi) {
// TODO: JDK8, replace this get/putIfAbsent logic with ConcurrentMap.computeIfAbsent
ServiceLoader<T> serviceLoader = (ServiceLoader<T>) SERVICE_CACHE.get(spi);
if (serviceLoader != null) {
return serviceLoader;
private static <T> T findCached(Class<T> spi) {
Assert.notNull(spi, "Service interface cannot be null.");
Object obj = SERVICES.get(spi);
if (obj != null) {
return Assert.isInstanceOf(spi, obj, "Unexpected cached service implementation type.");
}
for (ClassLoaderAccessor classLoaderAccessor : CLASS_LOADER_ACCESSORS) {
serviceLoader = ServiceLoader.load(spi, classLoaderAccessor.getClassLoader());
if (serviceLoader.iterator().hasNext()) {
SERVICE_CACHE.putIfAbsent(spi, serviceLoader);
return serviceLoader;
}
}
return null;
}
private static <T> T loadFirst(Class<T> spi) {
for (ClassLoaderAccessor accessor : CLASS_LOADER_ACCESSORS) {
ServiceLoader<T> loader = ServiceLoader.load(spi, accessor.getClassLoader());
Assert.stateNotNull(loader, "JDK ServiceLoader#load should never return null.");
Iterator<T> i = loader.iterator();
Assert.stateNotNull(i, "JDK ServiceLoader#iterator() should never return null.");
if (i.hasNext()) {
return i.next();
}
}
throw new UnavailableImplementationException(spi);
}
/**
* Clears internal cache of ServiceLoaders. This is useful when testing, or for applications that dynamically
* Clears internal cache of service singletons. This is useful when testing, or for applications that dynamically
* change classloaders.
*/
public static void reload() {
SERVICE_CACHE.clear();
SERVICES.clear();
}
private interface ClassLoaderAccessor {

View File

@ -32,7 +32,7 @@ public final class JwksBridge {
@SuppressWarnings({"unchecked", "unused"}) // used via reflection by io.jsonwebtoken.security.Jwks
public static String UNSAFE_JSON(Jwk<?> jwk) {
Serializer<Map<String, ?>> serializer = Services.loadFirst(Serializer.class);
Serializer<Map<String, ?>> serializer = Services.get(Serializer.class);
Assert.stateNotNull(serializer, "Serializer lookup failed. Ensure JSON impl .jar is in the runtime classpath.");
NamedSerializer ser = new NamedSerializer("JWK", serializer);
ByteArrayOutputStream out = new ByteArrayOutputStream(512);

View File

@ -75,7 +75,7 @@ class JwtsTest {
}
static def toJson(def o) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
def out = new ByteArrayOutputStream()
serializer.serialize(o, out)
return Strings.utf8(out.toByteArray())
@ -1192,7 +1192,7 @@ class JwtsTest {
int j = jws.lastIndexOf('.')
def b64 = jws.substring(i, j)
def json = Strings.utf8(Decoders.BASE64URL.decode(b64))
def deser = Services.loadFirst(Deserializer)
def deser = Services.get(Deserializer)
def m = deser.deserialize(new StringReader(json)) as Map<String,?>
assertEquals aud, m.get('aud') // single string value

View File

@ -29,8 +29,8 @@ import static org.junit.Assert.fail
class RFC7515AppendixETest {
static final Serializer<Map<String, ?>> serializer = Services.loadFirst(Serializer)
static final Deserializer<Map<String, ?>> deserializer = Services.loadFirst(Deserializer)
static final Serializer<Map<String, ?>> serializer = Services.get(Serializer)
static final Deserializer<Map<String, ?>> deserializer = Services.get(Deserializer)
static byte[] ser(def value) {
ByteArrayOutputStream baos = new ByteArrayOutputStream(512)

View File

@ -100,11 +100,9 @@ class RFC7797Test {
def claims = Jwts.claims().subject('me').build()
ByteArrayOutputStream out = new ByteArrayOutputStream()
Services.loadFirst(Serializer).serialize(claims, out)
Services.get(Serializer).serialize(claims, out)
byte[] content = out.toByteArray()
//byte[] content = Services.loadFirst(Serializer).serialize(claims)
String s = Jwts.builder().signWith(key).content(content).encodePayload(false).compact()
// But verify with 3 types of sources: string, byte array, and two different kinds of InputStreams:

View File

@ -45,7 +45,7 @@ class DefaultJwtBuilderTest {
private DefaultJwtBuilder builder
private static byte[] serialize(Map<String, ?> map) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
ByteArrayOutputStream out = new ByteArrayOutputStream(512)
serializer.serialize(map, out)
return out.toByteArray()
@ -53,7 +53,7 @@ class DefaultJwtBuilderTest {
private static Map<String, ?> deser(byte[] data) {
def reader = Streams.reader(data)
Map<String, ?> m = Services.loadFirst(Deserializer).deserialize(reader) as Map<String, ?>
Map<String, ?> m = Services.get(Deserializer).deserialize(reader) as Map<String, ?>
return m
}
@ -749,7 +749,7 @@ class DefaultJwtBuilderTest {
// so we need to check the raw payload:
def encoded = new JwtTokenizer().tokenize(Streams.reader(jwt)).getPayload()
byte[] bytes = Decoders.BASE64URL.decode(encoded)
def claims = Services.loadFirst(Deserializer).deserialize(Streams.reader(bytes))
def claims = Services.get(Deserializer).deserialize(Streams.reader(bytes))
assertEquals two, claims.aud
}

View File

@ -54,7 +54,7 @@ class DefaultJwtParserTest {
}
private static byte[] serialize(Map<String, ?> map) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
ByteArrayOutputStream out = new ByteArrayOutputStream(512)
serializer.serialize(map, out)
return out.toByteArray()

View File

@ -38,7 +38,7 @@ class RfcTests {
static final Map<String, ?> jsonToMap(String json) {
Reader r = new CharSequenceReader(json)
Map<String, ?> m = Services.loadFirst(Deserializer).deserialize(r) as Map<String, ?>
Map<String, ?> m = Services.get(Deserializer).deserialize(r) as Map<String, ?>
return m
}

View File

@ -20,32 +20,21 @@ import io.jsonwebtoken.impl.DefaultStubService
import org.junit.After
import org.junit.Test
import static org.junit.Assert.*
import static org.junit.Assert.assertEquals
import static org.junit.Assert.assertNotNull
class ServicesTest {
@Test
void testSuccessfulLoading() {
def factory = Services.loadFirst(StubService)
assertNotNull factory
assertEquals(DefaultStubService, factory.class)
def service = Services.get(StubService)
assertNotNull service
assertEquals(DefaultStubService, service.class)
}
@Test(expected = UnavailableImplementationException)
void testLoadFirstUnavailable() {
Services.loadFirst(NoService.class)
}
@Test
void testLoadAllAvailable() {
def list = Services.loadAll(StubService.class)
assertEquals 1, list.size()
assertTrue list[0] instanceof StubService
}
@Test(expected = UnavailableImplementationException)
void testLoadAllUnavailable() {
Services.loadAll(NoService.class)
void testLoadUnavailable() {
Services.get(NoService.class)
}
@Test

View File

@ -43,7 +43,7 @@ class RFC7518AppendixCTest {
}
private static final Map<String, ?> fromJson(String s) {
return Services.loadFirst(Deserializer).deserialize(new StringReader(s)) as Map<String, ?>
return Services.get(Deserializer).deserialize(new StringReader(s)) as Map<String, ?>
}
private static EcPrivateJwk readJwk(String json) {