LUCENE-9338: Clean up type safety in SimpleBindings (#1444)

Replaces SimpleBindings' Map<String, Object> with a map of
Function<Bindings, DoubleValuesSource> to improve type safety, and
reworks cycle detection and validation to avoid catching 
StackOverflowException
This commit is contained in:
Alan Woodward 2020-04-24 10:23:50 +01:00 committed by GitHub
parent 83018deef7
commit ed3caab2d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 30 deletions

View File

@ -221,6 +221,9 @@ Other
* LUCENE-9191: Make LineFileDocs's random seeking more efficient, making tests using LineFileDocs faster (Robert Muir,
Mike McCandless)
* LUCENE-9338: Refactors SimpleBindings to improve type safety and cycle detection (Alan Woodward,
Adrien Grand)
======================= Lucene 8.5.1 =======================
Bug Fixes

View File

@ -31,9 +31,8 @@ import org.apache.lucene.search.IndexSearcher;
/**
* A {@link DoubleValuesSource} which evaluates a {@link Expression} given the context of an {@link Bindings}.
*/
@SuppressWarnings({"rawtypes", "unchecked"})
final class ExpressionValueSource extends DoubleValuesSource {
final DoubleValuesSource variables[];
final DoubleValuesSource[] variables;
final Expression expression;
final boolean needsScores;

View File

@ -18,7 +18,10 @@ package org.apache.lucene.expressions;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.SortField;
@ -44,7 +47,8 @@ import org.apache.lucene.search.SortField;
* @lucene.experimental
*/
public final class SimpleBindings extends Bindings {
final Map<String,Object> map = new HashMap<>();
final Map<String, Function<Bindings, DoubleValuesSource>> map = new HashMap<>();
/** Creates a new empty Bindings */
public SimpleBindings() {}
@ -56,13 +60,13 @@ public final class SimpleBindings extends Bindings {
* FieldCache, the document's score, etc.
*/
public void add(SortField sortField) {
map.put(sortField.getField(), sortField);
map.put(sortField.getField(), bindings -> fromSortField(sortField));
}
/**
* Bind a {@link DoubleValuesSource} directly to the given name.
*/
public void add(String name, DoubleValuesSource source) { map.put(name, source); }
public void add(String name, DoubleValuesSource source) { map.put(name, bindings -> source); }
/**
* Adds an Expression to the bindings.
@ -70,20 +74,10 @@ public final class SimpleBindings extends Bindings {
* This can be used to reference expressions from other expressions.
*/
public void add(String name, Expression expression) {
map.put(name, expression);
map.put(name, expression::getDoubleValuesSource);
}
@Override
public DoubleValuesSource getDoubleValuesSource(String name) {
Object o = map.get(name);
if (o == null) {
throw new IllegalArgumentException("Invalid reference '" + name + "'");
} else if (o instanceof Expression) {
return ((Expression)o).getDoubleValuesSource(this);
} else if (o instanceof DoubleValuesSource) {
return ((DoubleValuesSource) o);
}
SortField field = (SortField) o;
private DoubleValuesSource fromSortField(SortField field) {
switch(field.getType()) {
case INT:
return DoubleValuesSource.fromIntField(field.getField());
@ -96,24 +90,51 @@ public final class SimpleBindings extends Bindings {
case SCORE:
return DoubleValuesSource.SCORES;
default:
throw new UnsupportedOperationException();
throw new UnsupportedOperationException();
}
}
/**
* Traverses the graph of bindings, checking there are no cycles or missing references
* @throws IllegalArgumentException if the bindings is inconsistent
@Override
public DoubleValuesSource getDoubleValuesSource(String name) {
if (map.containsKey(name) == false) {
throw new IllegalArgumentException("Invalid reference '" + name + "'");
}
return map.get(name).apply(this);
}
/**
* Traverses the graph of bindings, checking there are no cycles or missing references
* @throws IllegalArgumentException if the bindings is inconsistent
*/
public void validate() {
for (Object o : map.values()) {
if (o instanceof Expression) {
Expression expr = (Expression) o;
try {
expr.getDoubleValuesSource(this);
} catch (StackOverflowError e) {
throw new IllegalArgumentException("Recursion Error: Cycle detected originating in (" + expr.sourceText + ")");
}
for (Map.Entry<String, Function<Bindings, DoubleValuesSource>> origin : map.entrySet()) {
origin.getValue().apply(new CycleDetectionBindings(origin.getKey()));
}
}
private class CycleDetectionBindings extends Bindings {
private final Set<String> seenFields = new LinkedHashSet<>();
CycleDetectionBindings(String current) {
seenFields.add(current);
}
CycleDetectionBindings(Set<String> parents, String current) {
seenFields.addAll(parents);
seenFields.add(current);
}
@Override
public DoubleValuesSource getDoubleValuesSource(String name) {
if (seenFields.contains(name)) {
throw new IllegalArgumentException("Recursion error: Cycle detected " + seenFields + "->" + name);
}
if (map.containsKey(name) == false) {
throw new IllegalArgumentException("Invalid reference '" + name + "'");
}
return map.get(name).apply(new CycleDetectionBindings(seenFields, name));
}
}
}