Improving the Grok circular reference check to prevent stack overflow (#1079) (#1087)

This change refactors the circular reference check in the Grok processor class
to use a formal depth-first traversal. It also includes a logic update to
prevent a stack overflow in one scenario and a check for malformed patterns.
This bugfix addresses CVE-2021-22144.

Signed-off-by: Kartik Ganesh <85275476+kartg@users.noreply.github.com>
This commit is contained in:
kartg 2021-08-12 15:47:56 -07:00 committed by GitHub
parent e44e890d88
commit f151bfff24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 41 deletions

View File

@ -53,6 +53,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Stack;
import java.util.function.Consumer;
import static java.util.Collections.unmodifiableList;
@ -106,11 +107,7 @@ public final class Grok {
this.namedCaptures = namedCaptures;
this.matcherWatchdog = matcherWatchdog;
for (Map.Entry<String, String> entry : patternBank.entrySet()) {
String name = entry.getKey();
String pattern = entry.getValue();
forbidCircularReferences(name, new ArrayList<>(), pattern);
}
validatePatternBank();
String expression = toRegex(grokPattern);
byte[] expressionBytes = expression.getBytes(StandardCharsets.UTF_8);
@ -125,46 +122,68 @@ public final class Grok {
}
/**
* Checks whether patterns reference each other in a circular manner and if so fail with an exception
* Entry point to recursively validate the pattern bank for circular dependencies and malformed URLs
* via depth-first traversal. This implementation does not include memoization.
*/
private void validatePatternBank() {
for (String patternName : patternBank.keySet()) {
validatePatternBank(patternName, new Stack<>());
}
}
/**
* Checks whether patterns reference each other in a circular manner and, if so, fail with an exception.
* Also checks for malformed pattern definitions and fails with an exception.
*
* In a pattern, anything between <code>%{</code> and <code>}</code> or <code>:</code> is considered
* a reference to another named pattern. This method will navigate to all these named patterns and
* check for a circular reference.
*/
private void forbidCircularReferences(String patternName, List<String> path, String pattern) {
if (pattern.contains("%{" + patternName + "}") || pattern.contains("%{" + patternName + ":")) {
String message;
if (path.isEmpty()) {
message = "circular reference in pattern [" + patternName + "][" + pattern + "]";
} else {
message = "circular reference in pattern [" + path.remove(path.size() - 1) + "][" + pattern +
"] back to pattern [" + patternName + "]";
// add rest of the path:
if (path.isEmpty() == false) {
message += " via patterns [" + String.join("=>", path) + "]";
}
}
throw new IllegalArgumentException(message);
private void validatePatternBank(String patternName, Stack<String> path) {
String pattern = patternBank.get(patternName);
boolean isSelfReference = pattern.contains("%{" + patternName + "}") ||
pattern.contains("%{" + patternName + ":");
if (isSelfReference) {
throwExceptionForCircularReference(patternName, pattern);
} else if (path.contains(patternName)) {
// current pattern name is already in the path, fetch its predecessor
String prevPatternName = path.pop();
String prevPattern = patternBank.get(prevPatternName);
throwExceptionForCircularReference(prevPatternName, prevPattern, patternName, path);
}
path.push(patternName);
for (int i = pattern.indexOf("%{"); i != -1; i = pattern.indexOf("%{", i + 1)) {
int begin = i + 2;
int brackedIndex = pattern.indexOf('}', begin);
int columnIndex = pattern.indexOf(':', begin);
int end;
if (brackedIndex != -1 && columnIndex == -1) {
end = brackedIndex;
} else if (columnIndex != -1 && brackedIndex == -1) {
end = columnIndex;
} else if (brackedIndex != -1 && columnIndex != -1) {
end = Math.min(brackedIndex, columnIndex);
} else {
throw new IllegalArgumentException("pattern [" + pattern + "] has circular references to other pattern definitions");
int syntaxEndIndex = pattern.indexOf('}', begin);
if (syntaxEndIndex == -1) {
throw new IllegalArgumentException("Malformed pattern [" + patternName + "][" + pattern +"]");
}
String otherPatternName = pattern.substring(begin, end);
path.add(otherPatternName);
forbidCircularReferences(patternName, path, patternBank.get(otherPatternName));
int semanticNameIndex = pattern.indexOf(':', begin);
int end = syntaxEndIndex;
if (semanticNameIndex != -1) {
end = Math.min(syntaxEndIndex, semanticNameIndex);
}
String dependsOnPattern = pattern.substring(begin, end);
validatePatternBank(dependsOnPattern, path);
}
path.pop();
}
private static void throwExceptionForCircularReference(String patternName, String pattern) {
throwExceptionForCircularReference(patternName, pattern, null, null);
}
private static void throwExceptionForCircularReference(String patternName, String pattern, String originPatterName,
Stack<String> path) {
StringBuilder message = new StringBuilder("circular reference in pattern [");
message.append(patternName).append("][").append(pattern).append("]");
if (originPatterName != null) {
message.append(" back to pattern [").append(originPatterName).append("]");
}
if (path != null && path.size() > 1) {
message.append(" via patterns [").append(String.join("=>", path)).append("]");
}
throw new IllegalArgumentException(message.toString());
}
private String groupMatch(String name, Region region, String pattern) {

View File

@ -51,16 +51,16 @@ import java.util.function.Function;
import java.util.function.IntConsumer;
import java.util.function.LongConsumer;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.opensearch.grok.GrokCaptureType.BOOLEAN;
import static org.opensearch.grok.GrokCaptureType.DOUBLE;
import static org.opensearch.grok.GrokCaptureType.FLOAT;
import static org.opensearch.grok.GrokCaptureType.INTEGER;
import static org.opensearch.grok.GrokCaptureType.LONG;
import static org.opensearch.grok.GrokCaptureType.STRING;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
public class GrokTests extends OpenSearchTestCase {
@ -344,7 +344,17 @@ public class GrokTests extends OpenSearchTestCase {
String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn);
});
assertEquals("circular reference in pattern [NAME3][!!!%{NAME1}!!!] back to pattern [NAME1] via patterns [NAME2]",
assertEquals("circular reference in pattern [NAME3][!!!%{NAME1}!!!] back to pattern [NAME1] via patterns [NAME1=>NAME2]",
e.getMessage());
e = expectThrows(IllegalArgumentException.class, () -> {
Map<String, String> bank = new TreeMap<>();
bank.put("NAME1", "!!!%{NAME2}!!!");
bank.put("NAME2", "!!!%{NAME2}!!!");
String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn);
});
assertEquals("circular reference in pattern [NAME2][!!!%{NAME2}!!!]",
e.getMessage());
e = expectThrows(IllegalArgumentException.class, () -> {
@ -358,7 +368,25 @@ public class GrokTests extends OpenSearchTestCase {
new Grok(bank, pattern, false, logger::warn );
});
assertEquals("circular reference in pattern [NAME5][!!!%{NAME1}!!!] back to pattern [NAME1] " +
"via patterns [NAME2=>NAME3=>NAME4]", e.getMessage());
"via patterns [NAME1=>NAME2=>NAME3=>NAME4]", e.getMessage());
}
public void testMalformedPattern() {
Exception e = expectThrows(IllegalArgumentException.class, () -> {
Map<String, String> bank = new HashMap<>();
bank.put("NAME1", "!!!%{NAME2:!!!");
String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn);
});
assertEquals("Malformed pattern [NAME1][!!!%{NAME2:!!!]", e.getMessage());
e = expectThrows(IllegalArgumentException.class, () -> {
Map<String, String> bank = new HashMap<>();
bank.put("NAME1", "!!!%{NAME2!!!");
String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn);
});
assertEquals("Malformed pattern [NAME1][!!!%{NAME2!!!]", e.getMessage());
}
public void testBooleanCaptures() {