Improving the Grok circular reference check to prevent stack overflow (#1079) (#1087)

This change refactors the circular reference check in the Grok processor class
to use a formal depth-first traversal. It also includes a logic update to
prevent a stack overflow in one scenario and a check for malformed patterns.
This bugfix addresses CVE-2021-22144.

Signed-off-by: Kartik Ganesh <85275476+kartg@users.noreply.github.com>
This commit is contained in:
kartg 2021-08-12 15:47:56 -07:00 committed by GitHub
parent e44e890d88
commit f151bfff24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 41 deletions

View File

@ -53,6 +53,7 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Stack;
import java.util.function.Consumer; import java.util.function.Consumer;
import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableList;
@ -106,11 +107,7 @@ public final class Grok {
this.namedCaptures = namedCaptures; this.namedCaptures = namedCaptures;
this.matcherWatchdog = matcherWatchdog; this.matcherWatchdog = matcherWatchdog;
for (Map.Entry<String, String> entry : patternBank.entrySet()) { validatePatternBank();
String name = entry.getKey();
String pattern = entry.getValue();
forbidCircularReferences(name, new ArrayList<>(), pattern);
}
String expression = toRegex(grokPattern); String expression = toRegex(grokPattern);
byte[] expressionBytes = expression.getBytes(StandardCharsets.UTF_8); byte[] expressionBytes = expression.getBytes(StandardCharsets.UTF_8);
@ -125,46 +122,68 @@ public final class Grok {
} }
/** /**
* Checks whether patterns reference each other in a circular manner and if so fail with an exception * Entry point to recursively validate the pattern bank for circular dependencies and malformed URLs
* via depth-first traversal. This implementation does not include memoization.
*/
private void validatePatternBank() {
for (String patternName : patternBank.keySet()) {
validatePatternBank(patternName, new Stack<>());
}
}
/**
* Checks whether patterns reference each other in a circular manner and, if so, fail with an exception.
* Also checks for malformed pattern definitions and fails with an exception.
* *
* In a pattern, anything between <code>%{</code> and <code>}</code> or <code>:</code> is considered * In a pattern, anything between <code>%{</code> and <code>}</code> or <code>:</code> is considered
* a reference to another named pattern. This method will navigate to all these named patterns and * a reference to another named pattern. This method will navigate to all these named patterns and
* check for a circular reference. * check for a circular reference.
*/ */
private void forbidCircularReferences(String patternName, List<String> path, String pattern) { private void validatePatternBank(String patternName, Stack<String> path) {
if (pattern.contains("%{" + patternName + "}") || pattern.contains("%{" + patternName + ":")) { String pattern = patternBank.get(patternName);
String message; boolean isSelfReference = pattern.contains("%{" + patternName + "}") ||
if (path.isEmpty()) { pattern.contains("%{" + patternName + ":");
message = "circular reference in pattern [" + patternName + "][" + pattern + "]"; if (isSelfReference) {
} else { throwExceptionForCircularReference(patternName, pattern);
message = "circular reference in pattern [" + path.remove(path.size() - 1) + "][" + pattern + } else if (path.contains(patternName)) {
"] back to pattern [" + patternName + "]"; // current pattern name is already in the path, fetch its predecessor
// add rest of the path: String prevPatternName = path.pop();
if (path.isEmpty() == false) { String prevPattern = patternBank.get(prevPatternName);
message += " via patterns [" + String.join("=>", path) + "]"; throwExceptionForCircularReference(prevPatternName, prevPattern, patternName, path);
} }
} path.push(patternName);
throw new IllegalArgumentException(message);
}
for (int i = pattern.indexOf("%{"); i != -1; i = pattern.indexOf("%{", i + 1)) { for (int i = pattern.indexOf("%{"); i != -1; i = pattern.indexOf("%{", i + 1)) {
int begin = i + 2; int begin = i + 2;
int brackedIndex = pattern.indexOf('}', begin); int syntaxEndIndex = pattern.indexOf('}', begin);
int columnIndex = pattern.indexOf(':', begin); if (syntaxEndIndex == -1) {
int end; throw new IllegalArgumentException("Malformed pattern [" + patternName + "][" + pattern +"]");
if (brackedIndex != -1 && columnIndex == -1) {
end = brackedIndex;
} else if (columnIndex != -1 && brackedIndex == -1) {
end = columnIndex;
} else if (brackedIndex != -1 && columnIndex != -1) {
end = Math.min(brackedIndex, columnIndex);
} else {
throw new IllegalArgumentException("pattern [" + pattern + "] has circular references to other pattern definitions");
} }
String otherPatternName = pattern.substring(begin, end); int semanticNameIndex = pattern.indexOf(':', begin);
path.add(otherPatternName); int end = syntaxEndIndex;
forbidCircularReferences(patternName, path, patternBank.get(otherPatternName)); if (semanticNameIndex != -1) {
end = Math.min(syntaxEndIndex, semanticNameIndex);
} }
String dependsOnPattern = pattern.substring(begin, end);
validatePatternBank(dependsOnPattern, path);
}
path.pop();
}
private static void throwExceptionForCircularReference(String patternName, String pattern) {
throwExceptionForCircularReference(patternName, pattern, null, null);
}
private static void throwExceptionForCircularReference(String patternName, String pattern, String originPatterName,
Stack<String> path) {
StringBuilder message = new StringBuilder("circular reference in pattern [");
message.append(patternName).append("][").append(pattern).append("]");
if (originPatterName != null) {
message.append(" back to pattern [").append(originPatterName).append("]");
}
if (path != null && path.size() > 1) {
message.append(" via patterns [").append(String.join("=>", path)).append("]");
}
throw new IllegalArgumentException(message.toString());
} }
private String groupMatch(String name, Region region, String pattern) { private String groupMatch(String name, Region region, String pattern) {

View File

@ -51,16 +51,16 @@ import java.util.function.Function;
import java.util.function.IntConsumer; import java.util.function.IntConsumer;
import java.util.function.LongConsumer; import java.util.function.LongConsumer;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.opensearch.grok.GrokCaptureType.BOOLEAN; import static org.opensearch.grok.GrokCaptureType.BOOLEAN;
import static org.opensearch.grok.GrokCaptureType.DOUBLE; import static org.opensearch.grok.GrokCaptureType.DOUBLE;
import static org.opensearch.grok.GrokCaptureType.FLOAT; import static org.opensearch.grok.GrokCaptureType.FLOAT;
import static org.opensearch.grok.GrokCaptureType.INTEGER; import static org.opensearch.grok.GrokCaptureType.INTEGER;
import static org.opensearch.grok.GrokCaptureType.LONG; import static org.opensearch.grok.GrokCaptureType.LONG;
import static org.opensearch.grok.GrokCaptureType.STRING; import static org.opensearch.grok.GrokCaptureType.STRING;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
public class GrokTests extends OpenSearchTestCase { public class GrokTests extends OpenSearchTestCase {
@ -344,7 +344,17 @@ public class GrokTests extends OpenSearchTestCase {
String pattern = "%{NAME1}"; String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn); new Grok(bank, pattern, false, logger::warn);
}); });
assertEquals("circular reference in pattern [NAME3][!!!%{NAME1}!!!] back to pattern [NAME1] via patterns [NAME2]", assertEquals("circular reference in pattern [NAME3][!!!%{NAME1}!!!] back to pattern [NAME1] via patterns [NAME1=>NAME2]",
e.getMessage());
e = expectThrows(IllegalArgumentException.class, () -> {
Map<String, String> bank = new TreeMap<>();
bank.put("NAME1", "!!!%{NAME2}!!!");
bank.put("NAME2", "!!!%{NAME2}!!!");
String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn);
});
assertEquals("circular reference in pattern [NAME2][!!!%{NAME2}!!!]",
e.getMessage()); e.getMessage());
e = expectThrows(IllegalArgumentException.class, () -> { e = expectThrows(IllegalArgumentException.class, () -> {
@ -358,7 +368,25 @@ public class GrokTests extends OpenSearchTestCase {
new Grok(bank, pattern, false, logger::warn ); new Grok(bank, pattern, false, logger::warn );
}); });
assertEquals("circular reference in pattern [NAME5][!!!%{NAME1}!!!] back to pattern [NAME1] " + assertEquals("circular reference in pattern [NAME5][!!!%{NAME1}!!!] back to pattern [NAME1] " +
"via patterns [NAME2=>NAME3=>NAME4]", e.getMessage()); "via patterns [NAME1=>NAME2=>NAME3=>NAME4]", e.getMessage());
}
public void testMalformedPattern() {
Exception e = expectThrows(IllegalArgumentException.class, () -> {
Map<String, String> bank = new HashMap<>();
bank.put("NAME1", "!!!%{NAME2:!!!");
String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn);
});
assertEquals("Malformed pattern [NAME1][!!!%{NAME2:!!!]", e.getMessage());
e = expectThrows(IllegalArgumentException.class, () -> {
Map<String, String> bank = new HashMap<>();
bank.put("NAME1", "!!!%{NAME2!!!");
String pattern = "%{NAME1}";
new Grok(bank, pattern, false, logger::warn);
});
assertEquals("Malformed pattern [NAME1][!!!%{NAME2!!!]", e.getMessage());
} }
public void testBooleanCaptures() { public void testBooleanCaptures() {