From ed3caab2d86b69ec4b3ed8e787827c0931b43d1b Mon Sep 17 00:00:00 2001 From: Alan Woodward Date: Fri, 24 Apr 2020 10:23:50 +0100 Subject: [PATCH] LUCENE-9338: Clean up type safety in SimpleBindings (#1444) Replaces SimpleBindings' Map with a map of Function to improve type safety, and reworks cycle detection and validation to avoid catching StackOverflowException --- lucene/CHANGES.txt | 3 + .../expressions/ExpressionValueSource.java | 3 +- .../lucene/expressions/SimpleBindings.java | 77 ++++++++++++------- 3 files changed, 53 insertions(+), 30 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 71d51b087ae..db5d42a4410 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -221,6 +221,9 @@ Other * LUCENE-9191: Make LineFileDocs's random seeking more efficient, making tests using LineFileDocs faster (Robert Muir, Mike McCandless) +* LUCENE-9338: Refactors SimpleBindings to improve type safety and cycle detection (Alan Woodward, + Adrien Grand) + ======================= Lucene 8.5.1 ======================= Bug Fixes diff --git a/lucene/expressions/src/java/org/apache/lucene/expressions/ExpressionValueSource.java b/lucene/expressions/src/java/org/apache/lucene/expressions/ExpressionValueSource.java index 8cec5cc1533..f4fa894af8a 100644 --- a/lucene/expressions/src/java/org/apache/lucene/expressions/ExpressionValueSource.java +++ b/lucene/expressions/src/java/org/apache/lucene/expressions/ExpressionValueSource.java @@ -31,9 +31,8 @@ import org.apache.lucene.search.IndexSearcher; /** * A {@link DoubleValuesSource} which evaluates a {@link Expression} given the context of an {@link Bindings}. */ -@SuppressWarnings({"rawtypes", "unchecked"}) final class ExpressionValueSource extends DoubleValuesSource { - final DoubleValuesSource variables[]; + final DoubleValuesSource[] variables; final Expression expression; final boolean needsScores; diff --git a/lucene/expressions/src/java/org/apache/lucene/expressions/SimpleBindings.java b/lucene/expressions/src/java/org/apache/lucene/expressions/SimpleBindings.java index 627605532fe..3b2362ba87d 100644 --- a/lucene/expressions/src/java/org/apache/lucene/expressions/SimpleBindings.java +++ b/lucene/expressions/src/java/org/apache/lucene/expressions/SimpleBindings.java @@ -18,7 +18,10 @@ package org.apache.lucene.expressions; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.Map; +import java.util.Set; +import java.util.function.Function; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.SortField; @@ -44,7 +47,8 @@ import org.apache.lucene.search.SortField; * @lucene.experimental */ public final class SimpleBindings extends Bindings { - final Map map = new HashMap<>(); + + final Map> map = new HashMap<>(); /** Creates a new empty Bindings */ public SimpleBindings() {} @@ -56,13 +60,13 @@ public final class SimpleBindings extends Bindings { * FieldCache, the document's score, etc. */ public void add(SortField sortField) { - map.put(sortField.getField(), sortField); + map.put(sortField.getField(), bindings -> fromSortField(sortField)); } /** * Bind a {@link DoubleValuesSource} directly to the given name. */ - public void add(String name, DoubleValuesSource source) { map.put(name, source); } + public void add(String name, DoubleValuesSource source) { map.put(name, bindings -> source); } /** * Adds an Expression to the bindings. @@ -70,20 +74,10 @@ public final class SimpleBindings extends Bindings { * This can be used to reference expressions from other expressions. */ public void add(String name, Expression expression) { - map.put(name, expression); + map.put(name, expression::getDoubleValuesSource); } - - @Override - public DoubleValuesSource getDoubleValuesSource(String name) { - Object o = map.get(name); - if (o == null) { - throw new IllegalArgumentException("Invalid reference '" + name + "'"); - } else if (o instanceof Expression) { - return ((Expression)o).getDoubleValuesSource(this); - } else if (o instanceof DoubleValuesSource) { - return ((DoubleValuesSource) o); - } - SortField field = (SortField) o; + + private DoubleValuesSource fromSortField(SortField field) { switch(field.getType()) { case INT: return DoubleValuesSource.fromIntField(field.getField()); @@ -96,24 +90,51 @@ public final class SimpleBindings extends Bindings { case SCORE: return DoubleValuesSource.SCORES; default: - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException(); } } - /** - * Traverses the graph of bindings, checking there are no cycles or missing references - * @throws IllegalArgumentException if the bindings is inconsistent + @Override + public DoubleValuesSource getDoubleValuesSource(String name) { + if (map.containsKey(name) == false) { + throw new IllegalArgumentException("Invalid reference '" + name + "'"); + } + return map.get(name).apply(this); + } + + /** + * Traverses the graph of bindings, checking there are no cycles or missing references + * @throws IllegalArgumentException if the bindings is inconsistent */ public void validate() { - for (Object o : map.values()) { - if (o instanceof Expression) { - Expression expr = (Expression) o; - try { - expr.getDoubleValuesSource(this); - } catch (StackOverflowError e) { - throw new IllegalArgumentException("Recursion Error: Cycle detected originating in (" + expr.sourceText + ")"); - } + for (Map.Entry> origin : map.entrySet()) { + origin.getValue().apply(new CycleDetectionBindings(origin.getKey())); + } + } + + private class CycleDetectionBindings extends Bindings { + + private final Set seenFields = new LinkedHashSet<>(); + + CycleDetectionBindings(String current) { + seenFields.add(current); + } + + CycleDetectionBindings(Set parents, String current) { + seenFields.addAll(parents); + seenFields.add(current); + } + + @Override + public DoubleValuesSource getDoubleValuesSource(String name) { + if (seenFields.contains(name)) { + throw new IllegalArgumentException("Recursion error: Cycle detected " + seenFields + "->" + name); } + if (map.containsKey(name) == false) { + throw new IllegalArgumentException("Invalid reference '" + name + "'"); + } + return map.get(name).apply(new CycleDetectionBindings(seenFields, name)); } } } +