Issue #5535 - Adding regex include/exclude of Protocols to SslContextFactory

Signed-off-by: Joakim Erdfelt <joakim.erdfelt@gmail.com>
This commit is contained in:
Joakim Erdfelt 2020-10-30 10:29:35 -05:00
parent c969fba71a
commit 074b4f90f7
No known key found for this signature in database
GPG Key ID: 2D0E1FB8FE4B68B4
2 changed files with 68 additions and 68 deletions

View File

@ -47,14 +47,12 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import javax.net.ssl.CertPathTrustManagerParameters; import javax.net.ssl.CertPathTrustManagerParameters;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
@ -140,7 +138,7 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
private final Set<String> _excludeProtocols = new LinkedHashSet<>(); private final Set<String> _excludeProtocols = new LinkedHashSet<>();
private final Set<String> _includeProtocols = new LinkedHashSet<>(); private final Set<String> _includeProtocols = new LinkedHashSet<>();
private final Set<String> _excludeCipherSuites = new LinkedHashSet<>(); private final Set<String> _excludeCipherSuites = new LinkedHashSet<>();
private final List<String> _includeCipherSuites = new ArrayList<>(); private final Set<String> _includeCipherSuites = new LinkedHashSet<>();
private final Map<String, X509> _aliasX509 = new HashMap<>(); private final Map<String, X509> _aliasX509 = new HashMap<>();
private final Map<String, X509> _certHosts = new HashMap<>(); private final Map<String, X509> _certHosts = new HashMap<>();
private final Map<String, X509> _certWilds = new HashMap<>(); private final Map<String, X509> _certWilds = new HashMap<>();
@ -526,6 +524,8 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* You can either use the exact Protocol name or a a regular expression.
*
* @param protocols The array of protocol names to exclude from * @param protocols The array of protocol names to exclude from
* {@link SSLEngine#setEnabledProtocols(String[])} * {@link SSLEngine#setEnabledProtocols(String[])}
*/ */
@ -536,7 +536,9 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* @param protocol Protocol names to add to {@link SSLEngine#setEnabledProtocols(String[])} * You can either use the exact Protocol name or a a regular expression.
*
* @param protocol Protocol name patterns to add to {@link SSLEngine#setEnabledProtocols(String[])}
*/ */
public void addExcludeProtocols(String... protocol) public void addExcludeProtocols(String... protocol)
{ {
@ -544,7 +546,7 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* @return The array of protocol names to include in * @return The array of protocol name patterns to include in
* {@link SSLEngine#setEnabledProtocols(String[])} * {@link SSLEngine#setEnabledProtocols(String[])}
*/ */
@ManagedAttribute("The included TLS protocols") @ManagedAttribute("The included TLS protocols")
@ -554,7 +556,9 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* @param protocols The array of protocol names to include in * You can either use the exact Protocol name or a a regular expression.
*
* @param protocols The array of protocol name patterns to include in
* {@link SSLEngine#setEnabledProtocols(String[])} * {@link SSLEngine#setEnabledProtocols(String[])}
*/ */
public void setIncludeProtocols(String... protocols) public void setIncludeProtocols(String... protocols)
@ -564,7 +568,7 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* @return The array of cipher suite names to exclude from * @return The array of cipher suite name patterns to exclude from
* {@link SSLEngine#setEnabledCipherSuites(String[])} * {@link SSLEngine#setEnabledCipherSuites(String[])}
*/ */
@ManagedAttribute("The excluded cipher suites") @ManagedAttribute("The excluded cipher suites")
@ -574,7 +578,7 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* You can either use the exact cipher suite name or a a regular expression. * You can either use the exact Cipher suite name or a a regular expression.
* *
* @param cipherSuites The array of cipher suite names to exclude from * @param cipherSuites The array of cipher suite names to exclude from
* {@link SSLEngine#setEnabledCipherSuites(String[])} * {@link SSLEngine#setEnabledCipherSuites(String[])}
@ -586,6 +590,8 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* You can either use the exact Cipher suite name or a a regular expression.
*
* @param cipher Cipher names to add to {@link SSLEngine#setEnabledCipherSuites(String[])} * @param cipher Cipher names to add to {@link SSLEngine#setEnabledCipherSuites(String[])}
*/ */
public void addExcludeCipherSuites(String... cipher) public void addExcludeCipherSuites(String... cipher)
@ -594,7 +600,7 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* @return The array of cipher suite names to include in * @return The array of Cipher suite names to include in
* {@link SSLEngine#setEnabledCipherSuites(String[])} * {@link SSLEngine#setEnabledCipherSuites(String[])}
*/ */
@ManagedAttribute("The included cipher suites") @ManagedAttribute("The included cipher suites")
@ -604,7 +610,7 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
} }
/** /**
* You can either use the exact cipher suite name or a a regular expression. * You can either use the exact Cipher suite name or a a regular expression.
* *
* @param cipherSuites The array of cipher suite names to include in * @param cipherSuites The array of cipher suite names to include in
* {@link SSLEngine#setEnabledCipherSuites(String[])} * {@link SSLEngine#setEnabledCipherSuites(String[])}
@ -1357,28 +1363,10 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
*/ */
public void selectProtocols(String[] enabledProtocols, String[] supportedProtocols) public void selectProtocols(String[] enabledProtocols, String[] supportedProtocols)
{ {
Set<String> selectedProtocols = new LinkedHashSet<>(); List<String> selectedProtocols = processIncludeExcludePatterns("Protocols", enabledProtocols, supportedProtocols, _includeProtocols, _excludeProtocols);
// Set the starting protocols - either from the included or enabled list
if (!_includeProtocols.isEmpty())
{
// Use only the supported included protocols
for (String protocol : _includeProtocols)
{
if (Arrays.asList(supportedProtocols).contains(protocol))
selectedProtocols.add(protocol);
else
LOG.info("Protocol {} not supported in {}", protocol, Arrays.asList(supportedProtocols));
}
}
else
selectedProtocols.addAll(Arrays.asList(enabledProtocols));
// Remove any excluded protocols
selectedProtocols.removeAll(_excludeProtocols);
if (selectedProtocols.isEmpty()) if (selectedProtocols.isEmpty())
LOG.warn("No selected protocols from {}", Arrays.asList(supportedProtocols)); LOG.warn("No selected Protocols from {}", Arrays.asList(supportedProtocols));
_selectedProtocols = selectedProtocols.toArray(new String[0]); _selectedProtocols = selectedProtocols.toArray(new String[0]);
} }
@ -1393,18 +1381,10 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
*/ */
protected void selectCipherSuites(String[] enabledCipherSuites, String[] supportedCipherSuites) protected void selectCipherSuites(String[] enabledCipherSuites, String[] supportedCipherSuites)
{ {
List<String> selectedCiphers = new ArrayList<>(); List<String> selectedCiphers = processIncludeExcludePatterns("Cipher Suite", enabledCipherSuites, supportedCipherSuites, _includeCipherSuites, _excludeCipherSuites);
// Set the starting ciphers - either from the included or enabled list
if (_includeCipherSuites.isEmpty())
selectedCiphers.addAll(Arrays.asList(enabledCipherSuites));
else
processIncludeCipherSuites(supportedCipherSuites, selectedCiphers);
removeExcludedCipherSuites(selectedCiphers);
if (selectedCiphers.isEmpty()) if (selectedCiphers.isEmpty())
LOG.warn("No supported ciphers from {}", Arrays.asList(supportedCipherSuites)); LOG.warn("No supported Cipher Suite from {}", Arrays.asList(supportedCipherSuites));
Comparator<String> comparator = getCipherComparator(); Comparator<String> comparator = getCipherComparator();
if (comparator != null) if (comparator != null)
@ -1417,39 +1397,58 @@ public class SslContextFactory extends AbstractLifeCycle implements Dumpable
_selectedCipherSuites = selectedCiphers.toArray(new String[0]); _selectedCipherSuites = selectedCiphers.toArray(new String[0]);
} }
protected void processIncludeCipherSuites(String[] supportedCipherSuites, List<String> selectedCiphers) private List<String> processIncludeExcludePatterns(String type, String[] enabled, String[] supported, Set<String> included, Set<String> excluded)
{ {
for (String cipherSuite : _includeCipherSuites) List<String> selected = new ArrayList<>();
// Set the starting list - either from the included or enabled list
if (included.isEmpty())
{ {
Pattern p = Pattern.compile(cipherSuite); selected.addAll(Arrays.asList(enabled));
}
else
{
// process include patterns
for (String includedItem : included)
{
Pattern pattern = Pattern.compile(includedItem);
boolean added = false; boolean added = false;
for (String supportedCipherSuite : supportedCipherSuites) for (String supportedItem : supported)
{ {
Matcher m = p.matcher(supportedCipherSuite); if (pattern.matcher(supportedItem).matches())
if (m.matches())
{ {
added = true; added = true;
selectedCiphers.add(supportedCipherSuite); selected.add(supportedItem);
} }
} }
if (!added) if (!added)
LOG.info("No Cipher matching '{}' is supported", cipherSuite); LOG.info("No {} matching '{}' is supported", type, includedItem);
} }
} }
// process exclude patterns
for (String excludedItem : excluded)
{
Pattern pattern = Pattern.compile(excludedItem);
selected.removeIf(selectedItem -> pattern.matcher(selectedItem).matches());
}
return selected;
}
/**
* @deprecated no replacement
*/
@Deprecated
protected void processIncludeCipherSuites(String[] supportedCipherSuites, List<String> selectedCiphers)
{
}
/**
* @deprecated no replacement
*/
@Deprecated
protected void removeExcludedCipherSuites(List<String> selectedCiphers) protected void removeExcludedCipherSuites(List<String> selectedCiphers)
{ {
for (String excludeCipherSuite : _excludeCipherSuites)
{
Pattern excludeCipherPattern = Pattern.compile(excludeCipherSuite);
for (Iterator<String> i = selectedCiphers.iterator(); i.hasNext(); )
{
String selectedCipherSuite = i.next();
Matcher m = excludeCipherPattern.matcher(selectedCipherSuite);
if (m.matches())
i.remove();
}
}
} }
/** /**

View File

@ -102,11 +102,12 @@ public class SslContextFactoryTest
SslContextFactory.Server cf = new SslContextFactory.Server(); SslContextFactory.Server cf = new SslContextFactory.Server();
cf.setKeyStorePassword("storepwd"); cf.setKeyStorePassword("storepwd");
cf.setKeyManagerPassword("keypwd"); cf.setKeyManagerPassword("keypwd");
cf.setExcludeProtocols("TLSv1", "TLSv1.1"); cf.setExcludeProtocols("TLSv1\\.?[01]?");
cf.start(); cf.start();
// Confirm behavior in engine // Confirm behavior in engine
assertThat(cf.newSSLEngine().getEnabledProtocols(), not(hasItemInArray("TLSv1.1"))); assertThat(cf.newSSLEngine().getEnabledProtocols(), not(hasItemInArray("TLSv1.1")));
assertThat(cf.newSSLEngine().getEnabledProtocols(), not(hasItemInArray("TLSv1")));
// Confirm output in dump // Confirm output in dump
List<SslSelectionDump> dumps = cf.selectionDump(); List<SslSelectionDump> dumps = cf.selectionDump();
@ -125,7 +126,7 @@ public class SslContextFactoryTest
assertThat("Enabled Protocols TLSv1.1 count", countTls11Enabled, is(0L)); assertThat("Enabled Protocols TLSv1.1 count", countTls11Enabled, is(0L));
assertThat("Disabled Protocols TLSv1.1 count", countTls11Disabled, is(1L)); assertThat("Disabled Protocols TLSv1.1 count", countTls11Disabled, is(1L));
// Uncomment to show in console. // Uncomment to show dump in console.
// cf.dump(System.out, ""); // cf.dump(System.out, "");
} }