Refactor CrossOriginFilter with small perf improvements (#4672)

* Refactor CrossOriginFilter

Signed-off-by: Denny Abraham Cheriyan <dennyac@gmail.com>

* Fix checkstyle violation

Signed-off-by: Denny Abraham Cheriyan <dennyac@gmail.com>
This commit is contained in:
Denny Abraham Cheriyan 2020-03-16 14:51:00 +05:30 committed by Greg Wilkins
parent b497827df0
commit ce8d2ef168
1 changed files with 26 additions and 25 deletions

View File

@ -23,7 +23,9 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Filter;
@ -157,8 +159,10 @@ public class CrossOriginFilter implements Filter
private boolean anyOriginAllowed;
private boolean anyTimingOriginAllowed;
private boolean anyHeadersAllowed;
private List<String> allowedOrigins = new ArrayList<String>();
private List<String> allowedTimingOrigins = new ArrayList<String>();
private Set<String> allowedOrigins = new HashSet<String>();
private List<Pattern> allowedOriginPatterns = new ArrayList<Pattern>();
private Set<String> allowedTimingOrigins = new HashSet<String>();
private List<Pattern> allowedTimingOriginPatterns = new ArrayList<Pattern>();
private List<String> allowedMethods = new ArrayList<String>();
private List<String> allowedHeaders = new ArrayList<String>();
private List<String> exposedHeaders = new ArrayList<String>();
@ -172,8 +176,8 @@ public class CrossOriginFilter implements Filter
String allowedOriginsConfig = config.getInitParameter(ALLOWED_ORIGINS_PARAM);
String allowedTimingOriginsConfig = config.getInitParameter(ALLOWED_TIMING_ORIGINS_PARAM);
anyOriginAllowed = generateAllowedOrigins(allowedOrigins, allowedOriginsConfig, DEFAULT_ALLOWED_ORIGINS);
anyTimingOriginAllowed = generateAllowedOrigins(allowedTimingOrigins, allowedTimingOriginsConfig, DEFAULT_ALLOWED_TIMING_ORIGINS);
anyOriginAllowed = generateAllowedOrigins(allowedOrigins, allowedOriginPatterns, allowedOriginsConfig, DEFAULT_ALLOWED_ORIGINS);
anyTimingOriginAllowed = generateAllowedOrigins(allowedTimingOrigins, allowedTimingOriginPatterns, allowedTimingOriginsConfig, DEFAULT_ALLOWED_TIMING_ORIGINS);
String allowedMethodsConfig = config.getInitParameter(ALLOWED_METHODS_PARAM);
if (allowedMethodsConfig == null)
@ -235,7 +239,7 @@ public class CrossOriginFilter implements Filter
}
}
private boolean generateAllowedOrigins(List<String> allowedOriginStore, String allowedOriginsConfig, String defaultOrigin)
private boolean generateAllowedOrigins(Set<String> allowedOriginStore, List<Pattern> allowedOriginPatternStore, String allowedOriginsConfig, String defaultOrigin)
{
if (allowedOriginsConfig == null)
allowedOriginsConfig = defaultOrigin;
@ -247,8 +251,13 @@ public class CrossOriginFilter implements Filter
if (ANY_ORIGIN.equals(allowedOrigin))
{
allowedOriginStore.clear();
allowedOriginPatternStore.clear();
return true;
}
else if (allowedOrigin.contains("*"))
{
allowedOriginPatternStore.add(Pattern.compile(parseAllowedWildcardOriginToRegex(allowedOrigin)));
}
else
{
allowedOriginStore.add(allowedOrigin);
@ -270,7 +279,7 @@ public class CrossOriginFilter implements Filter
// Is it a cross origin request ?
if (origin != null && isEnabled(request))
{
if (anyOriginAllowed || originMatches(allowedOrigins, origin))
if (anyOriginAllowed || originMatches(allowedOrigins, allowedOriginPatterns, origin))
{
if (isSimpleRequest(request))
{
@ -292,7 +301,7 @@ public class CrossOriginFilter implements Filter
handleSimpleResponse(request, response, origin);
}
if (anyTimingOriginAllowed || originMatches(allowedTimingOrigins, origin))
if (anyTimingOriginAllowed || originMatches(allowedTimingOrigins, allowedTimingOriginPatterns, origin))
{
response.setHeader(TIMING_ALLOW_ORIGIN_HEADER, origin);
}
@ -330,7 +339,7 @@ public class CrossOriginFilter implements Filter
return true;
}
private boolean originMatches(List<String> allowedOrigins, String originList)
private boolean originMatches(Set<String> allowedOrigins, List<Pattern> allowedOriginPatterns, String originList)
{
if (originList.trim().length() == 0)
return false;
@ -341,30 +350,18 @@ public class CrossOriginFilter implements Filter
if (origin.trim().length() == 0)
continue;
for (String allowedOrigin : allowedOrigins)
if (allowedOrigins.contains(origin))
return true;
for (Pattern allowedOrigin : allowedOriginPatterns)
{
if (allowedOrigin.contains("*"))
{
Matcher matcher = createMatcher(origin, allowedOrigin);
if (matcher.matches())
return true;
}
else if (allowedOrigin.equals(origin))
{
if (allowedOrigin.matcher(origin).matches())
return true;
}
}
}
return false;
}
private Matcher createMatcher(String origin, String allowedOrigin)
{
String regex = parseAllowedWildcardOriginToRegex(allowedOrigin);
Pattern pattern = Pattern.compile(regex);
return pattern.matcher(origin);
}
private String parseAllowedWildcardOriginToRegex(String allowedOrigin)
{
String regex = StringUtil.replace(allowedOrigin, ".", "\\.");
@ -505,7 +502,11 @@ public class CrossOriginFilter implements Filter
public void destroy()
{
anyOriginAllowed = false;
anyTimingOriginAllowed = false;
allowedOrigins.clear();
allowedOriginPatterns.clear();
allowedTimingOrigins.clear();
allowedTimingOriginPatterns.clear();
allowedMethods.clear();
allowedHeaders.clear();
preflightMaxAge = 0;