NIFI-2621 - Generating unique serial numbers for certificates

This closes #909.

Signed-off-by: Andy LoPresto <alopresto@apache.org>
This commit is contained in:
Bryan Rosander 2016-08-22 12:57:37 -04:00 committed by Andy LoPresto
parent 6e82ec738c
commit 23350543ff
No known key found for this signature in database
GPG Key ID: 3C6EF65B2F7DEF69
2 changed files with 95 additions and 3 deletions

View File

@ -74,9 +74,23 @@ import java.util.concurrent.TimeUnit;
public final class CertificateUtils {
private static final Logger logger = LoggerFactory.getLogger(CertificateUtils.class);
private static final String PEER_NOT_AUTHENTICATED_MSG = "peer not authenticated";
private static final Map<ASN1ObjectIdentifier, Integer> dnOrderMap = createDnOrderMap();
/**
* The time in milliseconds that the last unique serial number was generated
*/
private static long lastSerialNumberMillis = 0L;
/**
* An incrementor to add uniqueness to serial numbers generated in the same millisecond
*/
private static int serialNumberIncrementor = 0;
/**
* BigInteger value to use for the base of the unique serial number
*/
private static BigInteger millisecondBigInteger;
private static Map<ASN1ObjectIdentifier, Integer> createDnOrderMap() {
Map<ASN1ObjectIdentifier, Integer> orderMap = new HashMap<>();
int count = 0;
@ -438,6 +452,29 @@ public final class CertificateUtils {
return new X500Name(rdns.toArray(new RDN[rdns.size()]));
}
/**
* Generates a unique serial number by using the current time in milliseconds left shifted 32 bits (to make room for incrementor) with an incrementor added
*
* @return a unique serial number (technically unique to this classloader)
*/
protected static synchronized BigInteger getUniqueSerialNumber() {
final long currentTimeMillis = System.currentTimeMillis();
final int incrementorValue;
if (lastSerialNumberMillis != currentTimeMillis) {
// We can only get into this block once per millisecond
millisecondBigInteger = BigInteger.valueOf(currentTimeMillis).shiftLeft(32);
lastSerialNumberMillis = currentTimeMillis;
incrementorValue = 0;
serialNumberIncrementor = 1;
} else {
// Already created at least one serial number this millisecond
incrementorValue = serialNumberIncrementor++;
}
return millisecondBigInteger.add(BigInteger.valueOf(incrementorValue));
}
/**
* Generates a self-signed {@link X509Certificate} suitable for use as a Certificate Authority.
*
@ -458,7 +495,7 @@ public final class CertificateUtils {
X509v3CertificateBuilder certBuilder = new X509v3CertificateBuilder(
reverseX500Name(new X500Name(dn)),
BigInteger.valueOf(System.currentTimeMillis()),
getUniqueSerialNumber(),
startDate, endDate,
reverseX500Name(new X500Name(dn)),
subPubKeyInfo);
@ -507,7 +544,7 @@ public final class CertificateUtils {
X509v3CertificateBuilder certBuilder = new X509v3CertificateBuilder(
reverseX500Name(new X500Name(issuer.getSubjectX500Principal().getName())),
BigInteger.valueOf(System.currentTimeMillis()),
getUniqueSerialNumber(),
startDate, endDate,
reverseX500Name(new X500Name(dn)),
subPubKeyInfo);

View File

@ -40,9 +40,16 @@ import java.security.SignatureException
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.Callable
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutionException
import java.util.concurrent.Executors
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import static org.junit.Assert.assertEquals
import static org.junit.Assert.assertTrue
@RunWith(JUnit4.class)
class CertificateUtilsTest extends GroovyTestCase {
@ -497,4 +504,52 @@ class CertificateUtilsTest extends GroovyTestCase {
assertEquals("$cn,$l,$st,$o,$ou,$c,$street,$dc,$uid,$surname,$givenName,$initials".toString(),
CertificateUtils.reorderDn("$surname,$st,$o,$initials,$givenName,$uid,$street,$c,$cn,$ou,$l,$dc"));
}
@Test
public void testUniqueSerialNumbers() {
def running = new AtomicBoolean(true);
def executorService = Executors.newCachedThreadPool()
def serialNumbers = Collections.newSetFromMap(new ConcurrentHashMap())
try {
def futures = new ArrayList<Future>()
for (int i = 0; i < 8; i++) {
futures.add(executorService.submit(new Callable<Integer>() {
@Override
Integer call() throws Exception {
int count = 0;
while (running.get()) {
def before = System.currentTimeMillis()
def serialNumber = CertificateUtils.getUniqueSerialNumber()
def after = System.currentTimeMillis()
def serialNumberMillis = serialNumber.shiftRight(32)
assertTrue(serialNumberMillis >= before)
assertTrue(serialNumberMillis <= after)
assertTrue(serialNumbers.add(serialNumber))
count++;
}
return count;
}
}));
}
Thread.sleep(1000)
running.set(false)
def totalRuns = 0;
for (int i = 0; i < futures.size(); i++) {
try {
def numTimes = futures.get(i).get()
logger.info("future $i executed $numTimes times")
totalRuns += numTimes;
} catch (ExecutionException e) {
throw e.getCause()
}
}
logger.info("Generated ${serialNumbers.size()} unique serial numbers")
assertEquals(totalRuns, serialNumbers.size())
} finally {
executorService.shutdown()
}
}
}